diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 3f4f631b..1c61e960 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -184,7 +184,7 @@ class Client(Methods, BaseClient): plugins: dict = None, no_updates: bool = None, takeout: bool = None, - sleep_threshold: int = 60 + sleep_threshold: int = Session.SLEEP_THRESHOLD ): super().__init__() @@ -1410,31 +1410,13 @@ class Client(Methods, BaseClient): if not self.is_connected: raise ConnectionError("Client has not been started yet") - # Some raw methods that expect a query as argument are used here. - # Keep the original request query because is needed. - unwrapped_data = data - if self.no_updates: data = functions.InvokeWithoutUpdates(query=data) if self.takeout_id: data = functions.InvokeWithTakeout(takeout_id=self.takeout_id, query=data) - while True: - try: - r = await self.session.send(data, retries, timeout) - except FloodWait as e: - amount = e.x - - if amount > self.sleep_threshold: - raise - - log.warning('[{}] Sleeping for {}s (required by "{}")'.format( - self.session_name, amount, ".".join(unwrapped_data.QUALNAME.split(".")[1:]))) - - await asyncio.sleep(amount) - else: - break + r = await self.session.send(data, retries, timeout, self.sleep_threshold) self.fetch_peers(getattr(r, "users", [])) self.fetch_peers(getattr(r, "chats", [])) diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 62399d69..3900f1d8 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -29,7 +29,7 @@ from pyrogram.api.all import layer from pyrogram.api.core import TLObject, MsgContainer, Int, Long, FutureSalt, FutureSalts from pyrogram.connection import Connection from pyrogram.crypto import MTProto -from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated +from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated, FloodWait from .internals import MsgId, MsgFactory log = logging.getLogger(__name__) @@ -46,6 +46,7 @@ class Session: NET_WORKERS = 1 START_TIMEOUT = 1 WAIT_TIMEOUT = 15 + SLEEP_THRESHOLD = 60 MAX_RETRIES = 5 ACKS_THRESHOLD = 8 PING_INTERVAL = 5 @@ -402,22 +403,47 @@ class Session: else: return result - async def send(self, data: TLObject, retries: int = MAX_RETRIES, timeout: float = WAIT_TIMEOUT): + async def send( + self, + data: TLObject, + retries: int = MAX_RETRIES, + timeout: float = WAIT_TIMEOUT, + sleep_threshold: float = SLEEP_THRESHOLD + ): try: await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT) except asyncio.TimeoutError: pass - try: - return await self._send(data, timeout=timeout) - except (OSError, TimeoutError, InternalServerError) as e: - if retries == 0: - raise e from None + if isinstance(data, (functions.InvokeWithoutUpdates, functions.InvokeWithTakeout)): + query = data.query + else: + query = data - (log.warning if retries < 2 else log.info)( - "[{}] Retrying {} due to {}".format( - Session.MAX_RETRIES - retries + 1, - data.QUALNAME, e)) + query = ".".join(query.QUALNAME.split(".")[1:]) - await asyncio.sleep(0.5) - return await self.send(data, retries - 1, timeout) + while True: + try: + return await self._send(data, timeout=timeout) + except FloodWait as e: + amount = e.x + + if amount > sleep_threshold: + raise + + log.warning('[{}] Sleeping for {}s (required by "{}")'.format( + self.client.session_name, amount, query)) + + await asyncio.sleep(amount) + except (OSError, TimeoutError, InternalServerError) as e: + if retries == 0: + raise e from None + + (log.warning if retries < 2 else log.info)( + '[{}] Retrying "{}" due to {}'.format( + Session.MAX_RETRIES - retries + 1, + query, e)) + + await asyncio.sleep(0.5) + + return await self.send(data, retries - 1, timeout)