Apply security checks to each message in the container

This commit is contained in:
Dan 2022-12-29 23:33:58 +01:00
parent 7ee47b220d
commit cf1e31c413
2 changed files with 40 additions and 45 deletions

View File

@ -16,18 +16,13 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import bisect
from hashlib import sha256
from io import BytesIO
from os import urandom
from typing import List
from pyrogram.errors import SecurityCheckMismatch
from pyrogram.raw.core import Message, Long
from . import aes
from ..session.internals import MsgId
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
@ -59,8 +54,7 @@ def unpack(
b: BytesIO,
session_id: bytes,
auth_key: bytes,
auth_key_id: bytes,
stored_msg_ids: List[int]
auth_key_id: bytes
) -> Message:
SecurityCheckMismatch.check(b.read(8) == auth_key_id, "b.read(8) == auth_key_id")
@ -103,26 +97,4 @@ def unpack(
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
SecurityCheckMismatch.check(message.msg_id % 2 != 0, "message.msg_id % 2 != 0")
if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE:
del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2]
if stored_msg_ids:
if message.msg_id < stored_msg_ids[0]:
raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
if message.msg_id in stored_msg_ids:
raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
time_diff = (message.msg_id - MsgId()) / 2 ** 32
if time_diff > 30:
raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
"Most likely the client time has to be synchronized.")
if time_diff < -300:
raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
"Most likely the client time has to be synchronized.")
bisect.insort(stored_msg_ids, message.msg_id)
return message

View File

@ -17,6 +17,7 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import bisect
import logging
import os
from hashlib import sha1
@ -32,7 +33,7 @@ from pyrogram.errors import (
)
from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
from .internals import MsgFactory
from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__)
@ -50,6 +51,7 @@ class Session:
MAX_RETRIES = 10
ACKS_THRESHOLD = 10
PING_INTERVAL = 5
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
def __init__(
self,
@ -176,20 +178,14 @@ class Session:
await self.start()
async def handle_packet(self, packet):
try:
data = await self.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.unpack,
BytesIO(packet),
self.session_id,
self.auth_key,
self.auth_key_id,
self.stored_msg_ids
self.auth_key_id
)
except SecurityCheckMismatch as e:
log.warning("Discarding packet: %s", e)
await self.connection.close()
return
messages = (
data.body.messages
@ -206,6 +202,33 @@ class Session:
else:
self.pending_acks.add(msg.msg_id)
try:
if len(self.stored_msg_ids) > Session.STORED_MSG_IDS_MAX_SIZE:
del self.stored_msg_ids[:Session.STORED_MSG_IDS_MAX_SIZE // 2]
if self.stored_msg_ids:
if msg.msg_id < self.stored_msg_ids[0]:
raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
if msg.msg_id in self.stored_msg_ids:
raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
time_diff = (msg.msg_id - MsgId()) / 2 ** 32
if time_diff > 30:
raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
"Most likely the client time has to be synchronized.")
if time_diff < -300:
raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
"Most likely the client time has to be synchronized.")
except SecurityCheckMismatch as e:
log.warning("Discarding packet: %s", e)
await self.connection.close()
return
else:
bisect.insort(self.stored_msg_ids, msg.msg_id)
if isinstance(msg.body, (raw.types.MsgDetailedInfo, raw.types.MsgNewDetailedInfo)):
self.pending_acks.add(msg.body.answer_msg_id)
continue
@ -323,7 +346,7 @@ class Session:
RPCError.raise_it(result, type(data))
if isinstance(result, raw.types.BadMsgNotification):
raise BadMsgNotification(result.error_code)
log.warning("%s: %s", BadMsgNotification.__name__, BadMsgNotification(result.error_code))
if isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt