diff --git a/pyrogram/crypto/mtproto.py b/pyrogram/crypto/mtproto.py index 803db297..1eec7b7a 100644 --- a/pyrogram/crypto/mtproto.py +++ b/pyrogram/crypto/mtproto.py @@ -19,9 +19,11 @@ from hashlib import sha256 from io import BytesIO from os import urandom +from typing import Optional, List from pyrogram.raw.core import Message, Long from . import aes +from ..session.internals import MsgId def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple: @@ -49,13 +51,19 @@ def pack(message: Message, salt: int, session_id: bytes, auth_key: bytes, auth_k return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv) -def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -> Message: +def unpack( + b: BytesIO, + session_id: bytes, + auth_key: bytes, + auth_key_id: bytes, + stored_msg_ids: List[int] +) -> Optional[Message]: assert b.read(8) == auth_key_id, b.getvalue() msg_key = b.read(16) aes_key, aes_iv = kdf(auth_key, msg_key, False) data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv)) - data.read(8) + data.read(8) # Salt # https://core.telegram.org/mtproto/security_guidelines#checking-session-id assert data.read(8) == session_id @@ -75,11 +83,41 @@ def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) - raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}") # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key - # https://core.telegram.org/mtproto/security_guidelines#checking-message-length # 96 = 88 + 8 (incoming message) assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24] + # https://core.telegram.org/mtproto/security_guidelines#checking-message-length + data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4) + payload = data.read() + padding = payload[message.length:] + assert 12 <= len(padding) <= 1024 + assert len(payload) % 4 == 0 + # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id assert message.msg_id % 2 != 0 + if len(stored_msg_ids) > 200: + stored_msg_ids = stored_msg_ids[50:] + + if stored_msg_ids: + # Ignored message: msg_id is lower than all of the stored values + if message.msg_id < stored_msg_ids[0]: + return None + + # Ignored message: msg_id is equal to any of the stored values + if message.msg_id in stored_msg_ids: + return None + + time_diff = (message.msg_id - MsgId()) / 2 ** 32 + + # Ignored message: msg_id belongs over 30 seconds in the future + if time_diff > 30: + return None + + # Ignored message: msg_id belongs over 300 seconds in the past + if time_diff < -300: + return None + + stored_msg_ids.append(message.msg_id) + return message diff --git a/pyrogram/raw/core/future_salt.py b/pyrogram/raw/core/future_salt.py index 85303d12..54a12963 100644 --- a/pyrogram/raw/core/future_salt.py +++ b/pyrogram/raw/core/future_salt.py @@ -42,3 +42,12 @@ class FutureSalt(TLObject): salt = Long.read(data) return FutureSalt(valid_since, valid_until, salt) + + def write(self, *args: Any) -> bytes: + b = BytesIO() + + b.write(Int(self.valid_since)) + b.write(Int(self.valid_until)) + b.write(Long(self.salt)) + + return b.getvalue() diff --git a/pyrogram/raw/core/future_salts.py b/pyrogram/raw/core/future_salts.py index faa4b741..9fa2f8e9 100644 --- a/pyrogram/raw/core/future_salts.py +++ b/pyrogram/raw/core/future_salts.py @@ -45,3 +45,17 @@ class FutureSalts(TLObject): salts = [FutureSalt.read(data) for _ in range(count)] return FutureSalts(req_msg_id, now, salts) + + def write(self, *args: Any) -> bytes: + b = BytesIO() + + b.write(Long(self.req_msg_id)) + b.write(Int(self.now)) + + count = len(self.salts) + b.write(Int(count)) + + for salt in self.salts: + b.write(salt.write()) + + return b.getvalue() diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 721586a0..39fe605c 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -102,6 +102,8 @@ class Session: self.results = {} + self.stored_msg_ids = [] + self.ping_task = None self.ping_task_event = asyncio.Event() @@ -224,9 +226,13 @@ class Session: BytesIO(packet), self.session_id, self.auth_key, - self.auth_key_id + self.auth_key_id, + self.stored_msg_ids ) + if data is None: + return + messages = ( data.body.messages if isinstance(data.body, MsgContainer)