From d298c62c6dea4c6fc0a5d8384a123481a21df2b0 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Mon, 26 Dec 2022 16:34:49 +0100 Subject: [PATCH] Update session.py --- pyrogram/session/session.py | 97 +++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 53 deletions(-) diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index fe13c743..2a22c2bd 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -32,7 +32,7 @@ from pyrogram.errors import ( ) from pyrogram.raw.all import layer from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts -from .internals import MsgId, MsgFactory +from .internals import MsgFactory log = logging.getLogger(__name__) @@ -85,9 +85,9 @@ class Session: self.ping_task = None self.ping_task_event = asyncio.Event() - self.network_task = None + self.recv_task = None - self.is_connected = asyncio.Event() + self.is_started = asyncio.Event() self.loop = asyncio.get_event_loop() @@ -104,7 +104,7 @@ class Session: try: await self.connection.connect() - self.network_task = self.loop.create_task(self.network_worker()) + self.recv_task = self.loop.create_task(self.recv_worker()) await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT) @@ -128,14 +128,13 @@ class Session: self.ping_task = self.loop.create_task(self.ping_worker()) - log.info(f"Session initialized: Layer {layer}") - log.info(f"Device: {self.client.device_model} - {self.client.app_version}") - log.info(f"System: {self.client.system_version} ({self.client.lang_code.upper()})") - + log.info("Session initialized: Layer %s", layer) + log.info("Device: %s - %s", self.client.device_model, self.client.app_version) + log.info("System: %s (%s)", self.client.system_version, self.client.lang_code) except AuthKeyDuplicated as e: await self.stop() raise e - except (OSError, TimeoutError, RPCError): + except (OSError, RPCError): await self.stop() except Exception as e: await self.stop() @@ -143,12 +142,12 @@ class Session: else: break - self.is_connected.set() + self.is_started.set() log.info("Session started") async def stop(self): - self.is_connected.clear() + self.is_started.clear() self.ping_task_event.set() @@ -159,17 +158,14 @@ class Session: await self.connection.close() - if self.network_task: - await self.network_task - - for i in self.results.values(): - i.event.set() + if self.recv_task: + await self.recv_task if not self.is_media and callable(self.client.disconnect_handler): try: await self.client.disconnect_handler(self.client) except Exception as e: - log.error(e, exc_info=True) + log.exception(e) log.info("Session stopped") @@ -189,7 +185,7 @@ class Session: self.stored_msg_ids ) except SecurityCheckMismatch as e: - log.info(f"Discarding packet: {e}") + log.info("Discarding packet: %s", e) await self.connection.close() return @@ -199,10 +195,7 @@ class Session: else [data] ) - # Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}") - # will cause "data" to be evaluated as string every time instead of only when debug is actually enabled. - log.debug("Received:") - log.debug(data) + log.debug("Received: %s", data) for msg in messages: if msg.seq_no % 2 != 0: @@ -235,11 +228,11 @@ class Session: self.results[msg_id].event.set() if len(self.pending_acks) >= self.ACKS_THRESHOLD: - log.debug(f"Send {len(self.pending_acks)} acks") + log.debug("Sending %s acks", len(self.pending_acks)) try: await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False) - except (OSError, TimeoutError): + except OSError: pass else: self.pending_acks.clear() @@ -261,12 +254,12 @@ class Session: ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10 ), False ) - except (OSError, TimeoutError, RPCError): + except (OSError, RPCError): pass log.info("PingTask stopped") - async def network_worker(self): + async def recv_worker(self): log.info("NetworkTask started") while True: @@ -274,9 +267,9 @@ class Session: if packet is None or len(packet) == 4: if packet: - log.warning(f'Server sent "{Int.read(BytesIO(packet))}"') + log.warning('Server sent "%s"', Int.read(BytesIO(packet))) - if self.is_connected.is_set(): + if self.is_started.is_set(): self.loop.create_task(self.restart()) break @@ -289,13 +282,7 @@ class Session: message = self.msg_factory(data) msg_id = message.msg_id - if wait_response: - self.results[msg_id] = Result() - - # Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}") - # will cause "data" to be evaluated as string every time instead of only when debug is actually enabled. - log.debug(f"Sent:") - log.debug(message) + log.debug("Sent: %s", message) payload = await self.loop.run_in_executor( pyrogram.crypto_executor, @@ -307,34 +294,35 @@ class Session: self.auth_key_id ) - try: - await self.connection.send(payload) - except OSError as e: - self.results.pop(msg_id, None) - raise e + await self.connection.send(payload) if wait_response: + self.results[msg_id] = Result() + try: await asyncio.wait_for(self.results[msg_id].event.wait(), timeout) except asyncio.TimeoutError: pass - finally: - result = self.results.pop(msg_id).value + + result = self.results.pop(msg_id).value if result is None: - raise TimeoutError - elif isinstance(result, raw.types.RpcError): + raise TimeoutError("Response timed out") + + if isinstance(result, raw.types.RpcError): if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)): data = data.query RPCError.raise_it(result, type(data)) - elif isinstance(result, raw.types.BadMsgNotification): + + if isinstance(result, raw.types.BadMsgNotification): raise BadMsgNotification(result.error_code) - elif isinstance(result, raw.types.BadServerSalt): + + if isinstance(result, raw.types.BadServerSalt): self.salt = result.new_server_salt return await self.send(data, wait_response, timeout) - else: - return result + + return result async def invoke( self, @@ -344,7 +332,7 @@ class Session: sleep_threshold: float = SLEEP_THRESHOLD ): try: - await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT) + await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT) except asyncio.TimeoutError: pass @@ -364,16 +352,19 @@ class Session: if amount > sleep_threshold >= 0: raise - log.warning(f'[{self.client.name}] Waiting for {amount} seconds before continuing ' - f'(required by "{query_name}")') + log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")', + self.client.name, amount, query_name) await asyncio.sleep(amount) - except (OSError, TimeoutError, InternalServerError, ServiceUnavailable) as e: + except (OSError, InternalServerError, ServiceUnavailable) as e: if retries == 0: raise e from None (log.warning if retries < 2 else log.info)( - f'[{Session.MAX_RETRIES - retries + 1}] Retrying "{query_name}" due to {str(e) or repr(e)}') + '[%s] Retrying "%s" due to: %s', + Session.MAX_RETRIES - retries + 1, + query_name, str(e) or repr(e) + ) await asyncio.sleep(0.5)