diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 6c6f0006..48be996b 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -18,6 +18,7 @@ import asyncio import bisect +import contextlib import logging import os from hashlib import sha1 @@ -97,6 +98,7 @@ class Session: self.recv_task = None self.is_started = asyncio.Event() + self.restart_event = asyncio.Event() self.loop = asyncio.get_event_loop() @@ -165,14 +167,16 @@ class Session: self.ping_task_event.set() if self.ping_task is not None: - await self.ping_task + with contextlib.suppress(Exception): + await self.ping_task self.ping_task_event.clear() await self.connection.close() if self.recv_task: - await self.recv_task + with contextlib.suppress(Exception): + await self.recv_task if not self.is_media and callable(self.client.disconnect_handler): try: @@ -183,8 +187,10 @@ class Session: log.info("Session stopped") async def restart(self): + self.restart_event.set() await self.stop() await self.start() + self.restart_event.clear() async def handle_packet(self, packet): try: @@ -424,6 +430,17 @@ class Session: query_name, str(e) or repr(e) ) + # restart was never being called after Exception block + if not self.restart_event.is_set(): + self.loop.create_task(self.restart()) + else: + # multiple Exceptions can be raised in a row, so we need to wait for the restart to finish + try: + await asyncio.wait_for(self.restart_event.wait(), self.WAIT_TIMEOUT) + except asyncio.TimeoutError: + if self.restart_event.is_set(): + self.restart_event.clear() + await asyncio.sleep(0.5) return await self.invoke(query, retries - 1, timeout)