Clean up on download's stop_transmission and return None

This commit is contained in:
Dan 2022-05-04 09:04:25 +02:00
parent 97b6c32c7f
commit 956e5c1a4f
2 changed files with 22 additions and 17 deletions

View File

@ -23,7 +23,7 @@ __copyright__ = "Copyright (C) 2017-present Dan <https://github.com/delivrance>"
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
class StopTransmission(StopAsyncIteration): class StopTransmission(Exception):
pass pass

View File

@ -727,23 +727,27 @@ class Client(Methods):
async def handle_download(self, packet): async def handle_download(self, packet):
file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet file_id, directory, file_name, in_memory, file_size, progress, progress_args = packet
file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb", delete=False) file = BytesIO() if in_memory else tempfile.NamedTemporaryFile("wb", delete=False)
async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args): try:
file.write(chunk) async for chunk in self.get_file(file_id, file_size, 0, 0, progress, progress_args):
file.write(chunk)
except pyrogram.StopTransmission:
if not in_memory:
file.close()
os.remove(file.name)
if file and not in_memory: return None
file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name))) else:
os.makedirs(directory, exist_ok=True) if in_memory:
file.close() file.name = file_name
shutil.move(file.name, file_path) return file
else:
return file_path file_path = os.path.abspath(re.sub("\\\\", "/", os.path.join(directory, file_name)))
os.makedirs(directory, exist_ok=True)
if file and in_memory: file.close()
file.name = file_name shutil.move(file.name, file_path)
return file return file_path
async def get_file( async def get_file(
self, self,
@ -970,9 +974,10 @@ class Client(Methods):
break break
except Exception as e: except Exception as e:
raise e raise e
except pyrogram.StopTransmission:
raise
except Exception as e: except Exception as e:
if not isinstance(e, pyrogram.StopTransmission): log.error(e, exc_info=True)
log.error(e, exc_info=True)
def guess_mime_type(self, filename: str) -> Optional[str]: def guess_mime_type(self, filename: str) -> Optional[str]:
return self.mimetypes.guess_type(filename)[0] return self.mimetypes.guess_type(filename)[0]