From 2a1af2b8e9afb54cc8f96bffb95317e52fbafcd3 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Wed, 15 Dec 2021 16:02:39 +0100 Subject: [PATCH] Close and reestablish the TCP connection in case of mismatch --- pyrogram/crypto/mtproto.py | 14 +++++++------- pyrogram/session/session.py | 25 +++++++++++++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/pyrogram/crypto/mtproto.py b/pyrogram/crypto/mtproto.py index 81cf56f4..a22693fe 100644 --- a/pyrogram/crypto/mtproto.py +++ b/pyrogram/crypto/mtproto.py @@ -20,7 +20,7 @@ import bisect from hashlib import sha256 from io import BytesIO from os import urandom -from typing import Optional, List +from typing import List, Tuple from pyrogram.raw.core import Message, Long from . import aes @@ -58,7 +58,7 @@ def unpack( auth_key: bytes, auth_key_id: bytes, stored_msg_ids: List[int] -) -> Optional[Message]: +) -> Tuple[Message, bool]: assert b.read(8) == auth_key_id, b.getvalue() msg_key = b.read(16) @@ -103,22 +103,22 @@ def unpack( if stored_msg_ids: # Ignored message: msg_id is lower than all of the stored values if message.msg_id < stored_msg_ids[0]: - return None + return message, False # Ignored message: msg_id is equal to any of the stored values if message.msg_id in stored_msg_ids: - return None + return message, False time_diff = (message.msg_id - MsgId()) / 2 ** 32 # Ignored message: msg_id belongs over 30 seconds in the future if time_diff > 30: - return None + return message, False # Ignored message: msg_id belongs over 300 seconds in the past if time_diff < -300: - return None + return message, False bisect.insort(stored_msg_ids, message.msg_id) - return message + return message, True diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 39fe605c..b3875f3c 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -220,17 +220,22 @@ class Session: await self.start() async def handle_packet(self, packet): - data = await self.loop.run_in_executor( - pyrogram.crypto_executor, - mtproto.unpack, - BytesIO(packet), - self.session_id, - self.auth_key, - self.auth_key_id, - self.stored_msg_ids - ) + try: + data, ok = await self.loop.run_in_executor( + pyrogram.crypto_executor, + mtproto.unpack, + BytesIO(packet), + self.session_id, + self.auth_key, + self.auth_key_id, + self.stored_msg_ids + ) + except AssertionError: + self.connection.close() + return - if data is None: + if not ok: + self.connection.close() return messages = (