Prevent connection to dc every time in get_file

This commit is contained in:
KurimuzonAkuma 2023-10-11 21:30:11 +03:00
parent efac17198b
commit 747dd11fb4

View File

@ -44,7 +44,7 @@ from pyrogram.errors import CDNFileHashMismatch
from pyrogram.errors import ( from pyrogram.errors import (
SessionPasswordNeeded, SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate, VolumeLocNotFound, ChannelPrivate,
BadRequest BadRequest, AuthBytesInvalid
) )
from pyrogram.handlers.handler import Handler from pyrogram.handlers.handler import Handler
from pyrogram.methods import Methods from pyrogram.methods import Methods
@ -862,7 +862,10 @@ class Client(Methods):
dc_id = file_id.dc_id dc_id = file_id.dc_id
session = Session( try:
session = self.media_sessions.get(dc_id)
if not session:
session = self.media_sessions[dc_id] = Session(
self, dc_id, self, dc_id,
await Auth(self, dc_id, await self.storage.test_mode()).create() await Auth(self, dc_id, await self.storage.test_mode()).create()
if dc_id != await self.storage.dc_id() if dc_id != await self.storage.dc_id()
@ -870,23 +873,29 @@ class Client(Methods):
await self.storage.test_mode(), await self.storage.test_mode(),
is_media=True is_media=True
) )
try:
await session.start() await session.start()
if dc_id != await self.storage.dc_id(): if dc_id != await self.storage.dc_id():
for _ in range(3):
exported_auth = await self.invoke( exported_auth = await self.invoke(
raw.functions.auth.ExportAuthorization( raw.functions.auth.ExportAuthorization(
dc_id=dc_id dc_id=dc_id
) )
) )
try:
await session.invoke( await session.invoke(
raw.functions.auth.ImportAuthorization( raw.functions.auth.ImportAuthorization(
id=exported_auth.id, id=exported_auth.id,
bytes=exported_auth.bytes bytes=exported_auth.bytes
) )
) )
except AuthBytesInvalid:
continue
else:
break
else:
raise AuthBytesInvalid
r = await session.invoke( r = await session.invoke(
raw.functions.upload.GetFile( raw.functions.upload.GetFile(
@ -1019,8 +1028,6 @@ class Client(Methods):
raise raise
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
finally:
await session.stop()
def guess_mime_type(self, filename: str) -> Optional[str]: def guess_mime_type(self, filename: str) -> Optional[str]:
return self.mimetypes.guess_type(filename)[0] return self.mimetypes.guess_type(filename)[0]