diff --git a/pyrogram/crypto/mtproto.py b/pyrogram/crypto/mtproto.py index e147c22a..6d1521a4 100644 --- a/pyrogram/crypto/mtproto.py +++ b/pyrogram/crypto/mtproto.py @@ -16,18 +16,13 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -import bisect from hashlib import sha256 from io import BytesIO from os import urandom -from typing import List from pyrogram.errors import SecurityCheckMismatch from pyrogram.raw.core import Message, Long from . import aes -from ..session.internals import MsgId - -STORED_MSG_IDS_MAX_SIZE = 1000 * 2 def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple: @@ -59,8 +54,7 @@ def unpack( b: BytesIO, session_id: bytes, auth_key: bytes, - auth_key_id: bytes, - stored_msg_ids: List[int] + auth_key_id: bytes ) -> Message: SecurityCheckMismatch.check(b.read(8) == auth_key_id, "b.read(8) == auth_key_id") @@ -103,26 +97,4 @@ def unpack( # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id SecurityCheckMismatch.check(message.msg_id % 2 != 0, "message.msg_id % 2 != 0") - if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE: - del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2] - - if stored_msg_ids: - if message.msg_id < stored_msg_ids[0]: - raise SecurityCheckMismatch("The msg_id is lower than all the stored values") - - if message.msg_id in stored_msg_ids: - raise SecurityCheckMismatch("The msg_id is equal to any of the stored values") - - time_diff = (message.msg_id - MsgId()) / 2 ** 32 - - if time_diff > 30: - raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. " - "Most likely the client time has to be synchronized.") - - if time_diff < -300: - raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. " - "Most likely the client time has to be synchronized.") - - bisect.insort(stored_msg_ids, message.msg_id) - return message diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 3ce96a8f..54814906 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -17,6 +17,7 @@ # along with Pyrogram. If not, see . import asyncio +import bisect import logging import os from hashlib import sha1 @@ -32,7 +33,7 @@ from pyrogram.errors import ( ) from pyrogram.raw.all import layer from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts -from .internals import MsgFactory +from .internals import MsgId, MsgFactory log = logging.getLogger(__name__) @@ -50,6 +51,7 @@ class Session: MAX_RETRIES = 10 ACKS_THRESHOLD = 10 PING_INTERVAL = 5 + STORED_MSG_IDS_MAX_SIZE = 1000 * 2 def __init__( self, @@ -176,20 +178,14 @@ class Session: await self.start() async def handle_packet(self, packet): - try: - data = await self.loop.run_in_executor( - pyrogram.crypto_executor, - mtproto.unpack, - BytesIO(packet), - self.session_id, - self.auth_key, - self.auth_key_id, - self.stored_msg_ids - ) - except SecurityCheckMismatch as e: - log.warning("Discarding packet: %s", e) - await self.connection.close() - return + data = await self.loop.run_in_executor( + pyrogram.crypto_executor, + mtproto.unpack, + BytesIO(packet), + self.session_id, + self.auth_key, + self.auth_key_id + ) messages = ( data.body.messages @@ -206,6 +202,33 @@ class Session: else: self.pending_acks.add(msg.msg_id) + try: + if len(self.stored_msg_ids) > Session.STORED_MSG_IDS_MAX_SIZE: + del self.stored_msg_ids[:Session.STORED_MSG_IDS_MAX_SIZE // 2] + + if self.stored_msg_ids: + if msg.msg_id < self.stored_msg_ids[0]: + raise SecurityCheckMismatch("The msg_id is lower than all the stored values") + + if msg.msg_id in self.stored_msg_ids: + raise SecurityCheckMismatch("The msg_id is equal to any of the stored values") + + time_diff = (msg.msg_id - MsgId()) / 2 ** 32 + + if time_diff > 30: + raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. " + "Most likely the client time has to be synchronized.") + + if time_diff < -300: + raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. " + "Most likely the client time has to be synchronized.") + except SecurityCheckMismatch as e: + log.warning("Discarding packet: %s", e) + await self.connection.close() + return + else: + bisect.insort(self.stored_msg_ids, msg.msg_id) + if isinstance(msg.body, (raw.types.MsgDetailedInfo, raw.types.MsgNewDetailedInfo)): self.pending_acks.add(msg.body.answer_msg_id) continue @@ -323,7 +346,7 @@ class Session: RPCError.raise_it(result, type(data)) if isinstance(result, raw.types.BadMsgNotification): - raise BadMsgNotification(result.error_code) + log.warning("%s: %s", BadMsgNotification.__name__, BadMsgNotification(result.error_code)) if isinstance(result, raw.types.BadServerSalt): self.salt = result.new_server_salt