mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-30 17:43:32 +00:00
Update session.py
This commit is contained in:
parent
7182a7cff7
commit
d298c62c6d
@ -32,7 +32,7 @@ from pyrogram.errors import (
|
||||
)
|
||||
from pyrogram.raw.all import layer
|
||||
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
|
||||
from .internals import MsgId, MsgFactory
|
||||
from .internals import MsgFactory
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -85,9 +85,9 @@ class Session:
|
||||
self.ping_task = None
|
||||
self.ping_task_event = asyncio.Event()
|
||||
|
||||
self.network_task = None
|
||||
self.recv_task = None
|
||||
|
||||
self.is_connected = asyncio.Event()
|
||||
self.is_started = asyncio.Event()
|
||||
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
@ -104,7 +104,7 @@ class Session:
|
||||
try:
|
||||
await self.connection.connect()
|
||||
|
||||
self.network_task = self.loop.create_task(self.network_worker())
|
||||
self.recv_task = self.loop.create_task(self.recv_worker())
|
||||
|
||||
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
|
||||
|
||||
@ -128,14 +128,13 @@ class Session:
|
||||
|
||||
self.ping_task = self.loop.create_task(self.ping_worker())
|
||||
|
||||
log.info(f"Session initialized: Layer {layer}")
|
||||
log.info(f"Device: {self.client.device_model} - {self.client.app_version}")
|
||||
log.info(f"System: {self.client.system_version} ({self.client.lang_code.upper()})")
|
||||
|
||||
log.info("Session initialized: Layer %s", layer)
|
||||
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
|
||||
log.info("System: %s (%s)", self.client.system_version, self.client.lang_code)
|
||||
except AuthKeyDuplicated as e:
|
||||
await self.stop()
|
||||
raise e
|
||||
except (OSError, TimeoutError, RPCError):
|
||||
except (OSError, RPCError):
|
||||
await self.stop()
|
||||
except Exception as e:
|
||||
await self.stop()
|
||||
@ -143,12 +142,12 @@ class Session:
|
||||
else:
|
||||
break
|
||||
|
||||
self.is_connected.set()
|
||||
self.is_started.set()
|
||||
|
||||
log.info("Session started")
|
||||
|
||||
async def stop(self):
|
||||
self.is_connected.clear()
|
||||
self.is_started.clear()
|
||||
|
||||
self.ping_task_event.set()
|
||||
|
||||
@ -159,17 +158,14 @@ class Session:
|
||||
|
||||
await self.connection.close()
|
||||
|
||||
if self.network_task:
|
||||
await self.network_task
|
||||
|
||||
for i in self.results.values():
|
||||
i.event.set()
|
||||
if self.recv_task:
|
||||
await self.recv_task
|
||||
|
||||
if not self.is_media and callable(self.client.disconnect_handler):
|
||||
try:
|
||||
await self.client.disconnect_handler(self.client)
|
||||
except Exception as e:
|
||||
log.error(e, exc_info=True)
|
||||
log.exception(e)
|
||||
|
||||
log.info("Session stopped")
|
||||
|
||||
@ -189,7 +185,7 @@ class Session:
|
||||
self.stored_msg_ids
|
||||
)
|
||||
except SecurityCheckMismatch as e:
|
||||
log.info(f"Discarding packet: {e}")
|
||||
log.info("Discarding packet: %s", e)
|
||||
await self.connection.close()
|
||||
return
|
||||
|
||||
@ -199,10 +195,7 @@ class Session:
|
||||
else [data]
|
||||
)
|
||||
|
||||
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
|
||||
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
|
||||
log.debug("Received:")
|
||||
log.debug(data)
|
||||
log.debug("Received: %s", data)
|
||||
|
||||
for msg in messages:
|
||||
if msg.seq_no % 2 != 0:
|
||||
@ -235,11 +228,11 @@ class Session:
|
||||
self.results[msg_id].event.set()
|
||||
|
||||
if len(self.pending_acks) >= self.ACKS_THRESHOLD:
|
||||
log.debug(f"Send {len(self.pending_acks)} acks")
|
||||
log.debug("Sending %s acks", len(self.pending_acks))
|
||||
|
||||
try:
|
||||
await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False)
|
||||
except (OSError, TimeoutError):
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
self.pending_acks.clear()
|
||||
@ -261,12 +254,12 @@ class Session:
|
||||
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
|
||||
), False
|
||||
)
|
||||
except (OSError, TimeoutError, RPCError):
|
||||
except (OSError, RPCError):
|
||||
pass
|
||||
|
||||
log.info("PingTask stopped")
|
||||
|
||||
async def network_worker(self):
|
||||
async def recv_worker(self):
|
||||
log.info("NetworkTask started")
|
||||
|
||||
while True:
|
||||
@ -274,9 +267,9 @@ class Session:
|
||||
|
||||
if packet is None or len(packet) == 4:
|
||||
if packet:
|
||||
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
|
||||
log.warning('Server sent "%s"', Int.read(BytesIO(packet)))
|
||||
|
||||
if self.is_connected.is_set():
|
||||
if self.is_started.is_set():
|
||||
self.loop.create_task(self.restart())
|
||||
|
||||
break
|
||||
@ -289,13 +282,7 @@ class Session:
|
||||
message = self.msg_factory(data)
|
||||
msg_id = message.msg_id
|
||||
|
||||
if wait_response:
|
||||
self.results[msg_id] = Result()
|
||||
|
||||
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
|
||||
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
|
||||
log.debug(f"Sent:")
|
||||
log.debug(message)
|
||||
log.debug("Sent: %s", message)
|
||||
|
||||
payload = await self.loop.run_in_executor(
|
||||
pyrogram.crypto_executor,
|
||||
@ -307,33 +294,34 @@ class Session:
|
||||
self.auth_key_id
|
||||
)
|
||||
|
||||
try:
|
||||
await self.connection.send(payload)
|
||||
except OSError as e:
|
||||
self.results.pop(msg_id, None)
|
||||
raise e
|
||||
|
||||
if wait_response:
|
||||
self.results[msg_id] = Result()
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
|
||||
result = self.results.pop(msg_id).value
|
||||
|
||||
if result is None:
|
||||
raise TimeoutError
|
||||
elif isinstance(result, raw.types.RpcError):
|
||||
raise TimeoutError("Response timed out")
|
||||
|
||||
if isinstance(result, raw.types.RpcError):
|
||||
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
|
||||
data = data.query
|
||||
|
||||
RPCError.raise_it(result, type(data))
|
||||
elif isinstance(result, raw.types.BadMsgNotification):
|
||||
|
||||
if isinstance(result, raw.types.BadMsgNotification):
|
||||
raise BadMsgNotification(result.error_code)
|
||||
elif isinstance(result, raw.types.BadServerSalt):
|
||||
|
||||
if isinstance(result, raw.types.BadServerSalt):
|
||||
self.salt = result.new_server_salt
|
||||
return await self.send(data, wait_response, timeout)
|
||||
else:
|
||||
|
||||
return result
|
||||
|
||||
async def invoke(
|
||||
@ -344,7 +332,7 @@ class Session:
|
||||
sleep_threshold: float = SLEEP_THRESHOLD
|
||||
):
|
||||
try:
|
||||
await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT)
|
||||
await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
@ -364,16 +352,19 @@ class Session:
|
||||
if amount > sleep_threshold >= 0:
|
||||
raise
|
||||
|
||||
log.warning(f'[{self.client.name}] Waiting for {amount} seconds before continuing '
|
||||
f'(required by "{query_name}")')
|
||||
log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")',
|
||||
self.client.name, amount, query_name)
|
||||
|
||||
await asyncio.sleep(amount)
|
||||
except (OSError, TimeoutError, InternalServerError, ServiceUnavailable) as e:
|
||||
except (OSError, InternalServerError, ServiceUnavailable) as e:
|
||||
if retries == 0:
|
||||
raise e from None
|
||||
|
||||
(log.warning if retries < 2 else log.info)(
|
||||
f'[{Session.MAX_RETRIES - retries + 1}] Retrying "{query_name}" due to {str(e) or repr(e)}')
|
||||
'[%s] Retrying "%s" due to: %s',
|
||||
Session.MAX_RETRIES - retries + 1,
|
||||
query_name, str(e) or repr(e)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user