Clean up on download's stop_transmission and return None

This commit is contained in:
Dan 2022-05-04 09:04:25 +02:00
parent 97b6c32c7f
commit 956e5c1a4f
2 changed files with 22 additions and 17 deletions

View File

@ -23,7 +23,7 @@ __copyright__ = "Copyright (C) 2017-present Dan <https://github.com/delivrance>"
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
class StopTransmission(StopAsyncIteration): class StopTransmission(Exception):
pass pass

View File

@ -727,24 +727,28 @@ class Client(Methods):
async def handle_download(self, packet): async def handle_download(self, packet):
file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb", delete=False) file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb", delete=False)
try:
async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args): async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args):
file.write(chunk) file.write(chunk)
except pyrogram.StopTransmission:
if not in_memory:
file.close()
os.remove(file.name)
if file and not in_memory: return None
else:
if in_memory:
file.name = file_name
return file
else:
file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
file.close() file.close()
shutil.move(file.name, file_path) shutil.move(file.name, file_path)
return file_path return file_path
if file and in_memory:
file.name = file_name
return file
async def get_file( async def get_file(
self, self,
file_id: FileId, file_id: FileId,
@ -970,8 +974,9 @@ class Client(Methods):
break break
except Exception as e: except Exception as e:
raise e raise e
except pyrogram.StopTransmission:
raise
except Exception as e: except Exception as e:
if not isinstance(e, pyrogram.StopTransmission):
log.error(e, exc_info=True) log.error(e, exc_info=True)
def guess_mime_type(self, filename: str) -> Optional[str]: def guess_mime_type(self, filename: str) -> Optional[str]: