Update session.py

This commit is contained in:
Dan 2022-12-26 16:34:49 +01:00
parent 7182a7cff7
commit d298c62c6d

View File

@ -32,7 +32,7 @@ from pyrogram.errors import (
) )
from pyrogram.raw.all import layer from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
from .internals import MsgId, MsgFactory from .internals import MsgFactory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -85,9 +85,9 @@ class Session:
self.ping_task = None self.ping_task = None
self.ping_task_event = asyncio.Event() self.ping_task_event = asyncio.Event()
self.network_task = None self.recv_task = None
self.is_connected = asyncio.Event() self.is_started = asyncio.Event()
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
@ -104,7 +104,7 @@ class Session:
try: try:
await self.connection.connect() await self.connection.connect()
self.network_task = self.loop.create_task(self.network_worker()) self.recv_task = self.loop.create_task(self.recv_worker())
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT) await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
@ -128,14 +128,13 @@ class Session:
self.ping_task = self.loop.create_task(self.ping_worker()) self.ping_task = self.loop.create_task(self.ping_worker())
log.info(f"Session initialized: Layer {layer}") log.info("Session initialized: Layer %s", layer)
log.info(f"Device: {self.client.device_model} - {self.client.app_version}") log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
log.info(f"System: {self.client.system_version} ({self.client.lang_code.upper()})") log.info("System: %s (%s)", self.client.system_version, self.client.lang_code)
except AuthKeyDuplicated as e: except AuthKeyDuplicated as e:
await self.stop() await self.stop()
raise e raise e
except (OSError, TimeoutError, RPCError): except (OSError, RPCError):
await self.stop() await self.stop()
except Exception as e: except Exception as e:
await self.stop() await self.stop()
@ -143,12 +142,12 @@ class Session:
else: else:
break break
self.is_connected.set() self.is_started.set()
log.info("Session started") log.info("Session started")
async def stop(self): async def stop(self):
self.is_connected.clear() self.is_started.clear()
self.ping_task_event.set() self.ping_task_event.set()
@ -159,17 +158,14 @@ class Session:
await self.connection.close() await self.connection.close()
if self.network_task: if self.recv_task:
await self.network_task await self.recv_task
for i in self.results.values():
i.event.set()
if not self.is_media and callable(self.client.disconnect_handler): if not self.is_media and callable(self.client.disconnect_handler):
try: try:
await self.client.disconnect_handler(self.client) await self.client.disconnect_handler(self.client)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.exception(e)
log.info("Session stopped") log.info("Session stopped")
@ -189,7 +185,7 @@ class Session:
self.stored_msg_ids self.stored_msg_ids
) )
except SecurityCheckMismatch as e: except SecurityCheckMismatch as e:
log.info(f"Discarding packet: {e}") log.info("Discarding packet: %s", e)
await self.connection.close() await self.connection.close()
return return
@ -199,10 +195,7 @@ class Session:
else [data] else [data]
) )
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}") log.debug("Received: %s", data)
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
log.debug("Received:")
log.debug(data)
for msg in messages: for msg in messages:
if msg.seq_no % 2 != 0: if msg.seq_no % 2 != 0:
@ -235,11 +228,11 @@ class Session:
self.results[msg_id].event.set() self.results[msg_id].event.set()
if len(self.pending_acks) >= self.ACKS_THRESHOLD: if len(self.pending_acks) >= self.ACKS_THRESHOLD:
log.debug(f"Send {len(self.pending_acks)} acks") log.debug("Sending %s acks", len(self.pending_acks))
try: try:
await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False) await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False)
except (OSError, TimeoutError): except OSError:
pass pass
else: else:
self.pending_acks.clear() self.pending_acks.clear()
@ -261,12 +254,12 @@ class Session:
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10 ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
), False ), False
) )
except (OSError, TimeoutError, RPCError): except (OSError, RPCError):
pass pass
log.info("PingTask stopped") log.info("PingTask stopped")
async def network_worker(self): async def recv_worker(self):
log.info("NetworkTask started") log.info("NetworkTask started")
while True: while True:
@ -274,9 +267,9 @@ class Session:
if packet is None or len(packet) == 4: if packet is None or len(packet) == 4:
if packet: if packet:
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"') log.warning('Server sent "%s"', Int.read(BytesIO(packet)))
if self.is_connected.is_set(): if self.is_started.is_set():
self.loop.create_task(self.restart()) self.loop.create_task(self.restart())
break break
@ -289,13 +282,7 @@ class Session:
message = self.msg_factory(data) message = self.msg_factory(data)
msg_id = message.msg_id msg_id = message.msg_id
if wait_response: log.debug("Sent: %s", message)
self.results[msg_id] = Result()
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
log.debug(f"Sent:")
log.debug(message)
payload = await self.loop.run_in_executor( payload = await self.loop.run_in_executor(
pyrogram.crypto_executor, pyrogram.crypto_executor,
@ -307,33 +294,34 @@ class Session:
self.auth_key_id self.auth_key_id
) )
try:
await self.connection.send(payload) await self.connection.send(payload)
except OSError as e:
self.results.pop(msg_id, None)
raise e
if wait_response: if wait_response:
self.results[msg_id] = Result()
try: try:
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout) await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
finally:
result = self.results.pop(msg_id).value result = self.results.pop(msg_id).value
if result is None: if result is None:
raise TimeoutError raise TimeoutError("Response timed out")
elif isinstance(result, raw.types.RpcError):
if isinstance(result, raw.types.RpcError):
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)): if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
data = data.query data = data.query
RPCError.raise_it(result, type(data)) RPCError.raise_it(result, type(data))
elif isinstance(result, raw.types.BadMsgNotification):
if isinstance(result, raw.types.BadMsgNotification):
raise BadMsgNotification(result.error_code) raise BadMsgNotification(result.error_code)
elif isinstance(result, raw.types.BadServerSalt):
if isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt self.salt = result.new_server_salt
return await self.send(data, wait_response, timeout) return await self.send(data, wait_response, timeout)
else:
return result return result
async def invoke( async def invoke(
@ -344,7 +332,7 @@ class Session:
sleep_threshold: float = SLEEP_THRESHOLD sleep_threshold: float = SLEEP_THRESHOLD
): ):
try: try:
await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT) await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
@ -364,16 +352,19 @@ class Session:
if amount > sleep_threshold >= 0: if amount > sleep_threshold >= 0:
raise raise
log.warning(f'[{self.client.name}] Waiting for {amount} seconds before continuing ' log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")',
f'(required by "{query_name}")') self.client.name, amount, query_name)
await asyncio.sleep(amount) await asyncio.sleep(amount)
except (OSError, TimeoutError, InternalServerError, ServiceUnavailable) as e: except (OSError, InternalServerError, ServiceUnavailable) as e:
if retries == 0: if retries == 0:
raise e from None raise e from None
(log.warning if retries < 2 else log.info)( (log.warning if retries < 2 else log.info)(
f'[{Session.MAX_RETRIES - retries + 1}] Retrying "{query_name}" due to {str(e) or repr(e)}') '[%s] Retrying "%s" due to: %s',
Session.MAX_RETRIES - retries + 1,
query_name, str(e) or repr(e)
)
await asyncio.sleep(0.5) await asyncio.sleep(0.5)