From cf1e31c413dc6b5ad03b4b78a73fc256b3e537c7 Mon Sep 17 00:00:00 2001
From: Dan <14043624+delivrance@users.noreply.github.com>
Date: Thu, 29 Dec 2022 23:33:58 +0100
Subject: [PATCH] Apply security checks to each message in the container
---
pyrogram/crypto/mtproto.py | 30 +-------------------
pyrogram/session/session.py | 55 ++++++++++++++++++++++++++-----------
2 files changed, 40 insertions(+), 45 deletions(-)
diff --git a/pyrogram/crypto/mtproto.py b/pyrogram/crypto/mtproto.py
index e147c22a..6d1521a4 100644
--- a/pyrogram/crypto/mtproto.py
+++ b/pyrogram/crypto/mtproto.py
@@ -16,18 +16,13 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
-import bisect
from hashlib import sha256
from io import BytesIO
from os import urandom
-from typing import List
from pyrogram.errors import SecurityCheckMismatch
from pyrogram.raw.core import Message, Long
from . import aes
-from ..session.internals import MsgId
-
-STORED_MSG_IDS_MAX_SIZE = 1000 * 2
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
@@ -59,8 +54,7 @@ def unpack(
b: BytesIO,
session_id: bytes,
auth_key: bytes,
- auth_key_id: bytes,
- stored_msg_ids: List[int]
+ auth_key_id: bytes
) -> Message:
SecurityCheckMismatch.check(b.read(8) == auth_key_id, "b.read(8) == auth_key_id")
@@ -103,26 +97,4 @@ def unpack(
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
SecurityCheckMismatch.check(message.msg_id % 2 != 0, "message.msg_id % 2 != 0")
- if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE:
- del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2]
-
- if stored_msg_ids:
- if message.msg_id < stored_msg_ids[0]:
- raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
-
- if message.msg_id in stored_msg_ids:
- raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
-
- time_diff = (message.msg_id - MsgId()) / 2 ** 32
-
- if time_diff > 30:
- raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
- "Most likely the client time has to be synchronized.")
-
- if time_diff < -300:
- raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
- "Most likely the client time has to be synchronized.")
-
- bisect.insort(stored_msg_ids, message.msg_id)
-
return message
diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py
index 3ce96a8f..54814906 100644
--- a/pyrogram/session/session.py
+++ b/pyrogram/session/session.py
@@ -17,6 +17,7 @@
# along with Pyrogram. If not, see .
import asyncio
+import bisect
import logging
import os
from hashlib import sha1
@@ -32,7 +33,7 @@ from pyrogram.errors import (
)
from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
-from .internals import MsgFactory
+from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__)
@@ -50,6 +51,7 @@ class Session:
MAX_RETRIES = 10
ACKS_THRESHOLD = 10
PING_INTERVAL = 5
+ STORED_MSG_IDS_MAX_SIZE = 1000 * 2
def __init__(
self,
@@ -176,20 +178,14 @@ class Session:
await self.start()
async def handle_packet(self, packet):
- try:
- data = await self.loop.run_in_executor(
- pyrogram.crypto_executor,
- mtproto.unpack,
- BytesIO(packet),
- self.session_id,
- self.auth_key,
- self.auth_key_id,
- self.stored_msg_ids
- )
- except SecurityCheckMismatch as e:
- log.warning("Discarding packet: %s", e)
- await self.connection.close()
- return
+ data = await self.loop.run_in_executor(
+ pyrogram.crypto_executor,
+ mtproto.unpack,
+ BytesIO(packet),
+ self.session_id,
+ self.auth_key,
+ self.auth_key_id
+ )
messages = (
data.body.messages
@@ -206,6 +202,33 @@ class Session:
else:
self.pending_acks.add(msg.msg_id)
+ try:
+ if len(self.stored_msg_ids) > Session.STORED_MSG_IDS_MAX_SIZE:
+ del self.stored_msg_ids[:Session.STORED_MSG_IDS_MAX_SIZE // 2]
+
+ if self.stored_msg_ids:
+ if msg.msg_id < self.stored_msg_ids[0]:
+ raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
+
+ if msg.msg_id in self.stored_msg_ids:
+ raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
+
+ time_diff = (msg.msg_id - MsgId()) / 2 ** 32
+
+ if time_diff > 30:
+ raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
+ "Most likely the client time has to be synchronized.")
+
+ if time_diff < -300:
+ raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
+ "Most likely the client time has to be synchronized.")
+ except SecurityCheckMismatch as e:
+ log.warning("Discarding packet: %s", e)
+ await self.connection.close()
+ return
+ else:
+ bisect.insort(self.stored_msg_ids, msg.msg_id)
+
if isinstance(msg.body, (raw.types.MsgDetailedInfo, raw.types.MsgNewDetailedInfo)):
self.pending_acks.add(msg.body.answer_msg_id)
continue
@@ -323,7 +346,7 @@ class Session:
RPCError.raise_it(result, type(data))
if isinstance(result, raw.types.BadMsgNotification):
- raise BadMsgNotification(result.error_code)
+ log.warning("%s: %s", BadMsgNotification.__name__, BadMsgNotification(result.error_code))
if isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt