Update session.py

This commit is contained in:
Dan 2022-12-26 16:34:49 +01:00
parent 7182a7cff7
commit d298c62c6d

View File

@ -32,7 +32,7 @@ from pyrogram.errors import (
)
from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
from .internals import MsgId, MsgFactory
from .internals import MsgFactory
log = logging.getLogger(__name__)
@ -85,9 +85,9 @@ class Session:
self.ping_task = None
self.ping_task_event = asyncio.Event()
self.network_task = None
self.recv_task = None
self.is_connected = asyncio.Event()
self.is_started = asyncio.Event()
self.loop = asyncio.get_event_loop()
@ -104,7 +104,7 @@ class Session:
try:
await self.connection.connect()
self.network_task = self.loop.create_task(self.network_worker())
self.recv_task = self.loop.create_task(self.recv_worker())
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
@ -128,14 +128,13 @@ class Session:
self.ping_task = self.loop.create_task(self.ping_worker())
log.info(f"Session initialized: Layer {layer}")
log.info(f"Device: {self.client.device_model} - {self.client.app_version}")
log.info(f"System: {self.client.system_version} ({self.client.lang_code.upper()})")
log.info("Session initialized: Layer %s", layer)
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
log.info("System: %s (%s)", self.client.system_version, self.client.lang_code)
except AuthKeyDuplicated as e:
await self.stop()
raise e
except (OSError, TimeoutError, RPCError):
except (OSError, RPCError):
await self.stop()
except Exception as e:
await self.stop()
@ -143,12 +142,12 @@ class Session:
else:
break
self.is_connected.set()
self.is_started.set()
log.info("Session started")
async def stop(self):
self.is_connected.clear()
self.is_started.clear()
self.ping_task_event.set()
@ -159,17 +158,14 @@ class Session:
await self.connection.close()
if self.network_task:
await self.network_task
for i in self.results.values():
i.event.set()
if self.recv_task:
await self.recv_task
if not self.is_media and callable(self.client.disconnect_handler):
try:
await self.client.disconnect_handler(self.client)
except Exception as e:
log.error(e, exc_info=True)
log.exception(e)
log.info("Session stopped")
@ -189,7 +185,7 @@ class Session:
self.stored_msg_ids
)
except SecurityCheckMismatch as e:
log.info(f"Discarding packet: {e}")
log.info("Discarding packet: %s", e)
await self.connection.close()
return
@ -199,10 +195,7 @@ class Session:
else [data]
)
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
log.debug("Received:")
log.debug(data)
log.debug("Received: %s", data)
for msg in messages:
if msg.seq_no % 2 != 0:
@ -235,11 +228,11 @@ class Session:
self.results[msg_id].event.set()
if len(self.pending_acks) >= self.ACKS_THRESHOLD:
log.debug(f"Send {len(self.pending_acks)} acks")
log.debug("Sending %s acks", len(self.pending_acks))
try:
await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False)
except (OSError, TimeoutError):
except OSError:
pass
else:
self.pending_acks.clear()
@ -261,12 +254,12 @@ class Session:
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
), False
)
except (OSError, TimeoutError, RPCError):
except (OSError, RPCError):
pass
log.info("PingTask stopped")
async def network_worker(self):
async def recv_worker(self):
log.info("NetworkTask started")
while True:
@ -274,9 +267,9 @@ class Session:
if packet is None or len(packet) == 4:
if packet:
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
log.warning('Server sent "%s"', Int.read(BytesIO(packet)))
if self.is_connected.is_set():
if self.is_started.is_set():
self.loop.create_task(self.restart())
break
@ -289,13 +282,7 @@ class Session:
message = self.msg_factory(data)
msg_id = message.msg_id
if wait_response:
self.results[msg_id] = Result()
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
log.debug(f"Sent:")
log.debug(message)
log.debug("Sent: %s", message)
payload = await self.loop.run_in_executor(
pyrogram.crypto_executor,
@ -307,34 +294,35 @@ class Session:
self.auth_key_id
)
try:
await self.connection.send(payload)
except OSError as e:
self.results.pop(msg_id, None)
raise e
await self.connection.send(payload)
if wait_response:
self.results[msg_id] = Result()
try:
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
except asyncio.TimeoutError:
pass
finally:
result = self.results.pop(msg_id).value
result = self.results.pop(msg_id).value
if result is None:
raise TimeoutError
elif isinstance(result, raw.types.RpcError):
raise TimeoutError("Response timed out")
if isinstance(result, raw.types.RpcError):
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
data = data.query
RPCError.raise_it(result, type(data))
elif isinstance(result, raw.types.BadMsgNotification):
if isinstance(result, raw.types.BadMsgNotification):
raise BadMsgNotification(result.error_code)
elif isinstance(result, raw.types.BadServerSalt):
if isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt
return await self.send(data, wait_response, timeout)
else:
return result
return result
async def invoke(
self,
@ -344,7 +332,7 @@ class Session:
sleep_threshold: float = SLEEP_THRESHOLD
):
try:
await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT)
await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT)
except asyncio.TimeoutError:
pass
@ -364,16 +352,19 @@ class Session:
if amount > sleep_threshold >= 0:
raise
log.warning(f'[{self.client.name}] Waiting for {amount} seconds before continuing '
f'(required by "{query_name}")')
log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")',
self.client.name, amount, query_name)
await asyncio.sleep(amount)
except (OSError, TimeoutError, InternalServerError, ServiceUnavailable) as e:
except (OSError, InternalServerError, ServiceUnavailable) as e:
if retries == 0:
raise e from None
(log.warning if retries < 2 else log.info)(
f'[{Session.MAX_RETRIES - retries + 1}] Retrying "{query_name}" due to {str(e) or repr(e)}')
'[%s] Retrying "%s" due to: %s',
Session.MAX_RETRIES - retries + 1,
query_name, str(e) or repr(e)
)
await asyncio.sleep(0.5)