Simplify the error handling a bit

This commit is contained in:
Dan 2021-12-15 19:26:54 +01:00
parent c2a29c8c30
commit ed9c7e4694
2 changed files with 9 additions and 13 deletions

View File

@ -20,7 +20,7 @@ import bisect
from hashlib import sha256 from hashlib import sha256
from io import BytesIO from io import BytesIO
from os import urandom from os import urandom
from typing import List, Tuple from typing import List
from pyrogram.raw.core import Message, Long from pyrogram.raw.core import Message, Long
from . import aes from . import aes
@ -60,8 +60,8 @@ def unpack(
auth_key: bytes, auth_key: bytes,
auth_key_id: bytes, auth_key_id: bytes,
stored_msg_ids: List[int] stored_msg_ids: List[int]
) -> Tuple[Message, bool]: ) -> Message:
assert b.read(8) == auth_key_id, b.getvalue() assert b.read(8) == auth_key_id
msg_key = b.read(16) msg_key = b.read(16)
aes_key, aes_iv = kdf(auth_key, msg_key, False) aes_key, aes_iv = kdf(auth_key, msg_key, False)
@ -105,22 +105,22 @@ def unpack(
if stored_msg_ids: if stored_msg_ids:
# Ignored message: msg_id is lower than all of the stored values # Ignored message: msg_id is lower than all of the stored values
if message.msg_id < stored_msg_ids[0]: if message.msg_id < stored_msg_ids[0]:
return message, False assert False
# Ignored message: msg_id is equal to any of the stored values # Ignored message: msg_id is equal to any of the stored values
if message.msg_id in stored_msg_ids: if message.msg_id in stored_msg_ids:
return message, False assert False
time_diff = (message.msg_id - MsgId()) / 2 ** 32 time_diff = (message.msg_id - MsgId()) / 2 ** 32
# Ignored message: msg_id belongs over 30 seconds in the future # Ignored message: msg_id belongs over 30 seconds in the future
if time_diff > 30: if time_diff > 30:
return message, False assert False
# Ignored message: msg_id belongs over 300 seconds in the past # Ignored message: msg_id belongs over 300 seconds in the past
if time_diff < -300: if time_diff < -300:
return message, False assert False
bisect.insort(stored_msg_ids, message.msg_id) bisect.insort(stored_msg_ids, message.msg_id)
return message, True return message

View File

@ -221,7 +221,7 @@ class Session:
async def handle_packet(self, packet): async def handle_packet(self, packet):
try: try:
data, ok = await self.loop.run_in_executor( data = await self.loop.run_in_executor(
pyrogram.crypto_executor, pyrogram.crypto_executor,
mtproto.unpack, mtproto.unpack,
BytesIO(packet), BytesIO(packet),
@ -234,10 +234,6 @@ class Session:
self.connection.close() self.connection.close()
return return
if not ok:
self.connection.close()
return
messages = ( messages = (
data.body.messages data.body.messages
if isinstance(data.body, MsgContainer) if isinstance(data.body, MsgContainer)