mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-28 00:56:19 +00:00
Update session.py
This commit is contained in:
parent
7182a7cff7
commit
d298c62c6d
@ -32,7 +32,7 @@ from pyrogram.errors import (
|
|||||||
)
|
)
|
||||||
from pyrogram.raw.all import layer
|
from pyrogram.raw.all import layer
|
||||||
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
|
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
|
||||||
from .internals import MsgId, MsgFactory
|
from .internals import MsgFactory
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -85,9 +85,9 @@ class Session:
|
|||||||
self.ping_task = None
|
self.ping_task = None
|
||||||
self.ping_task_event = asyncio.Event()
|
self.ping_task_event = asyncio.Event()
|
||||||
|
|
||||||
self.network_task = None
|
self.recv_task = None
|
||||||
|
|
||||||
self.is_connected = asyncio.Event()
|
self.is_started = asyncio.Event()
|
||||||
|
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ class Session:
|
|||||||
try:
|
try:
|
||||||
await self.connection.connect()
|
await self.connection.connect()
|
||||||
|
|
||||||
self.network_task = self.loop.create_task(self.network_worker())
|
self.recv_task = self.loop.create_task(self.recv_worker())
|
||||||
|
|
||||||
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
|
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
|
||||||
|
|
||||||
@ -128,14 +128,13 @@ class Session:
|
|||||||
|
|
||||||
self.ping_task = self.loop.create_task(self.ping_worker())
|
self.ping_task = self.loop.create_task(self.ping_worker())
|
||||||
|
|
||||||
log.info(f"Session initialized: Layer {layer}")
|
log.info("Session initialized: Layer %s", layer)
|
||||||
log.info(f"Device: {self.client.device_model} - {self.client.app_version}")
|
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
|
||||||
log.info(f"System: {self.client.system_version} ({self.client.lang_code.upper()})")
|
log.info("System: %s (%s)", self.client.system_version, self.client.lang_code)
|
||||||
|
|
||||||
except AuthKeyDuplicated as e:
|
except AuthKeyDuplicated as e:
|
||||||
await self.stop()
|
await self.stop()
|
||||||
raise e
|
raise e
|
||||||
except (OSError, TimeoutError, RPCError):
|
except (OSError, RPCError):
|
||||||
await self.stop()
|
await self.stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.stop()
|
await self.stop()
|
||||||
@ -143,12 +142,12 @@ class Session:
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
self.is_connected.set()
|
self.is_started.set()
|
||||||
|
|
||||||
log.info("Session started")
|
log.info("Session started")
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.is_connected.clear()
|
self.is_started.clear()
|
||||||
|
|
||||||
self.ping_task_event.set()
|
self.ping_task_event.set()
|
||||||
|
|
||||||
@ -159,17 +158,14 @@ class Session:
|
|||||||
|
|
||||||
await self.connection.close()
|
await self.connection.close()
|
||||||
|
|
||||||
if self.network_task:
|
if self.recv_task:
|
||||||
await self.network_task
|
await self.recv_task
|
||||||
|
|
||||||
for i in self.results.values():
|
|
||||||
i.event.set()
|
|
||||||
|
|
||||||
if not self.is_media and callable(self.client.disconnect_handler):
|
if not self.is_media and callable(self.client.disconnect_handler):
|
||||||
try:
|
try:
|
||||||
await self.client.disconnect_handler(self.client)
|
await self.client.disconnect_handler(self.client)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(e, exc_info=True)
|
log.exception(e)
|
||||||
|
|
||||||
log.info("Session stopped")
|
log.info("Session stopped")
|
||||||
|
|
||||||
@ -189,7 +185,7 @@ class Session:
|
|||||||
self.stored_msg_ids
|
self.stored_msg_ids
|
||||||
)
|
)
|
||||||
except SecurityCheckMismatch as e:
|
except SecurityCheckMismatch as e:
|
||||||
log.info(f"Discarding packet: {e}")
|
log.info("Discarding packet: %s", e)
|
||||||
await self.connection.close()
|
await self.connection.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -199,10 +195,7 @@ class Session:
|
|||||||
else [data]
|
else [data]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
|
log.debug("Received: %s", data)
|
||||||
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
|
|
||||||
log.debug("Received:")
|
|
||||||
log.debug(data)
|
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if msg.seq_no % 2 != 0:
|
if msg.seq_no % 2 != 0:
|
||||||
@ -235,11 +228,11 @@ class Session:
|
|||||||
self.results[msg_id].event.set()
|
self.results[msg_id].event.set()
|
||||||
|
|
||||||
if len(self.pending_acks) >= self.ACKS_THRESHOLD:
|
if len(self.pending_acks) >= self.ACKS_THRESHOLD:
|
||||||
log.debug(f"Send {len(self.pending_acks)} acks")
|
log.debug("Sending %s acks", len(self.pending_acks))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False)
|
await self.send(raw.types.MsgsAck(msg_ids=list(self.pending_acks)), False)
|
||||||
except (OSError, TimeoutError):
|
except OSError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.pending_acks.clear()
|
self.pending_acks.clear()
|
||||||
@ -261,12 +254,12 @@ class Session:
|
|||||||
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
|
ping_id=0, disconnect_delay=self.WAIT_TIMEOUT + 10
|
||||||
), False
|
), False
|
||||||
)
|
)
|
||||||
except (OSError, TimeoutError, RPCError):
|
except (OSError, RPCError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log.info("PingTask stopped")
|
log.info("PingTask stopped")
|
||||||
|
|
||||||
async def network_worker(self):
|
async def recv_worker(self):
|
||||||
log.info("NetworkTask started")
|
log.info("NetworkTask started")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -274,9 +267,9 @@ class Session:
|
|||||||
|
|
||||||
if packet is None or len(packet) == 4:
|
if packet is None or len(packet) == 4:
|
||||||
if packet:
|
if packet:
|
||||||
log.warning(f'Server sent "{Int.read(BytesIO(packet))}"')
|
log.warning('Server sent "%s"', Int.read(BytesIO(packet)))
|
||||||
|
|
||||||
if self.is_connected.is_set():
|
if self.is_started.is_set():
|
||||||
self.loop.create_task(self.restart())
|
self.loop.create_task(self.restart())
|
||||||
|
|
||||||
break
|
break
|
||||||
@ -289,13 +282,7 @@ class Session:
|
|||||||
message = self.msg_factory(data)
|
message = self.msg_factory(data)
|
||||||
msg_id = message.msg_id
|
msg_id = message.msg_id
|
||||||
|
|
||||||
if wait_response:
|
log.debug("Sent: %s", message)
|
||||||
self.results[msg_id] = Result()
|
|
||||||
|
|
||||||
# Call log.debug twice because calling it once by appending "data" to the previous string (i.e. f"Kind: {data}")
|
|
||||||
# will cause "data" to be evaluated as string every time instead of only when debug is actually enabled.
|
|
||||||
log.debug(f"Sent:")
|
|
||||||
log.debug(message)
|
|
||||||
|
|
||||||
payload = await self.loop.run_in_executor(
|
payload = await self.loop.run_in_executor(
|
||||||
pyrogram.crypto_executor,
|
pyrogram.crypto_executor,
|
||||||
@ -307,33 +294,34 @@ class Session:
|
|||||||
self.auth_key_id
|
self.auth_key_id
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
await self.connection.send(payload)
|
await self.connection.send(payload)
|
||||||
except OSError as e:
|
|
||||||
self.results.pop(msg_id, None)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
if wait_response:
|
if wait_response:
|
||||||
|
self.results[msg_id] = Result()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
|
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass
|
pass
|
||||||
finally:
|
|
||||||
result = self.results.pop(msg_id).value
|
result = self.results.pop(msg_id).value
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
raise TimeoutError
|
raise TimeoutError("Response timed out")
|
||||||
elif isinstance(result, raw.types.RpcError):
|
|
||||||
|
if isinstance(result, raw.types.RpcError):
|
||||||
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
|
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
|
||||||
data = data.query
|
data = data.query
|
||||||
|
|
||||||
RPCError.raise_it(result, type(data))
|
RPCError.raise_it(result, type(data))
|
||||||
elif isinstance(result, raw.types.BadMsgNotification):
|
|
||||||
|
if isinstance(result, raw.types.BadMsgNotification):
|
||||||
raise BadMsgNotification(result.error_code)
|
raise BadMsgNotification(result.error_code)
|
||||||
elif isinstance(result, raw.types.BadServerSalt):
|
|
||||||
|
if isinstance(result, raw.types.BadServerSalt):
|
||||||
self.salt = result.new_server_salt
|
self.salt = result.new_server_salt
|
||||||
return await self.send(data, wait_response, timeout)
|
return await self.send(data, wait_response, timeout)
|
||||||
else:
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def invoke(
|
async def invoke(
|
||||||
@ -344,7 +332,7 @@ class Session:
|
|||||||
sleep_threshold: float = SLEEP_THRESHOLD
|
sleep_threshold: float = SLEEP_THRESHOLD
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT)
|
await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -364,16 +352,19 @@ class Session:
|
|||||||
if amount > sleep_threshold >= 0:
|
if amount > sleep_threshold >= 0:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
log.warning(f'[{self.client.name}] Waiting for {amount} seconds before continuing '
|
log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")',
|
||||||
f'(required by "{query_name}")')
|
self.client.name, amount, query_name)
|
||||||
|
|
||||||
await asyncio.sleep(amount)
|
await asyncio.sleep(amount)
|
||||||
except (OSError, TimeoutError, InternalServerError, ServiceUnavailable) as e:
|
except (OSError, InternalServerError, ServiceUnavailable) as e:
|
||||||
if retries == 0:
|
if retries == 0:
|
||||||
raise e from None
|
raise e from None
|
||||||
|
|
||||||
(log.warning if retries < 2 else log.info)(
|
(log.warning if retries < 2 else log.info)(
|
||||||
f'[{Session.MAX_RETRIES - retries + 1}] Retrying "{query_name}" due to {str(e) or repr(e)}')
|
'[%s] Retrying "%s" due to: %s',
|
||||||
|
Session.MAX_RETRIES - retries + 1,
|
||||||
|
query_name, str(e) or repr(e)
|
||||||
|
)
|
||||||
|
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user