mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-18 13:34:54 +00:00
Make save_file more efficient
This commit is contained in:
parent
aa800c3ebc
commit
d28f795aca
@ -1070,6 +1070,15 @@ class Client(Methods, BaseClient):
|
|||||||
file_part: int = 0,
|
file_part: int = 0,
|
||||||
progress: callable = None,
|
progress: callable = None,
|
||||||
progress_args: tuple = ()):
|
progress_args: tuple = ()):
|
||||||
|
async def worker():
|
||||||
|
while True:
|
||||||
|
data = await queue.get()
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
await asyncio.ensure_future(session.send(data))
|
||||||
|
|
||||||
part_size = 512 * 1024
|
part_size = 512 * 1024
|
||||||
file_size = os.path.getsize(path)
|
file_size = os.path.getsize(path)
|
||||||
file_total_parts = int(math.ceil(file_size / part_size))
|
file_total_parts = int(math.ceil(file_size / part_size))
|
||||||
@ -1077,11 +1086,13 @@ class Client(Methods, BaseClient):
|
|||||||
is_missing_part = True if file_id is not None else False
|
is_missing_part = True if file_id is not None else False
|
||||||
file_id = file_id or self.rnd_id()
|
file_id = file_id or self.rnd_id()
|
||||||
md5_sum = md5() if not is_big and not is_missing_part else None
|
md5_sum = md5() if not is_big and not is_missing_part else None
|
||||||
|
|
||||||
session = Session(self, self.dc_id, self.auth_key, is_media=True)
|
session = Session(self, self.dc_id, self.auth_key, is_media=True)
|
||||||
await session.start()
|
workers = [asyncio.ensure_future(worker()) for _ in range(4)]
|
||||||
|
queue = asyncio.Queue(16)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
await session.start()
|
||||||
|
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
f.seek(part_size * file_part)
|
f.seek(part_size * file_part)
|
||||||
|
|
||||||
@ -1107,7 +1118,7 @@ class Client(Methods, BaseClient):
|
|||||||
bytes=chunk
|
bytes=chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
assert await session.send(rpc), "Couldn't upload file"
|
await queue.put(rpc)
|
||||||
|
|
||||||
if is_missing_part:
|
if is_missing_part:
|
||||||
return
|
return
|
||||||
@ -1137,6 +1148,10 @@ class Client(Methods, BaseClient):
|
|||||||
md5_checksum=md5_sum
|
md5_checksum=md5_sum
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
|
for _ in workers:
|
||||||
|
await queue.put(None)
|
||||||
|
|
||||||
|
await asyncio.gather(*workers)
|
||||||
await session.stop()
|
await session.stop()
|
||||||
|
|
||||||
async def get_file(self,
|
async def get_file(self,
|
||||||
|
Loading…
Reference in New Issue
Block a user