Apply security checks to each message in the container

This commit is contained in:
Dan 2022-12-29 23:33:58 +01:00
parent 7ee47b220d
commit cf1e31c413
2 changed files with 40 additions and 45 deletions

View File

@ -16,18 +16,13 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import bisect
from hashlib import sha256 from hashlib import sha256
from io import BytesIO from io import BytesIO
from os import urandom from os import urandom
from typing import List
from pyrogram.errors import SecurityCheckMismatch from pyrogram.errors import SecurityCheckMismatch
from pyrogram.raw.core import Message, Long from pyrogram.raw.core import Message, Long
from . import aes from . import aes
from ..session.internals import MsgId
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple: def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
@ -59,8 +54,7 @@ def unpack(
b: BytesIO, b: BytesIO,
session_id: bytes, session_id: bytes,
auth_key: bytes, auth_key: bytes,
auth_key_id: bytes, auth_key_id: bytes
stored_msg_ids: List[int]
) -> Message: ) -> Message:
SecurityCheckMismatch.check(b.read(8) == auth_key_id, "b.read(8) == auth_key_id") SecurityCheckMismatch.check(b.read(8) == auth_key_id, "b.read(8) == auth_key_id")
@ -103,26 +97,4 @@ def unpack(
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
SecurityCheckMismatch.check(message.msg_id % 2 != 0, "message.msg_id % 2 != 0") SecurityCheckMismatch.check(message.msg_id % 2 != 0, "message.msg_id % 2 != 0")
if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE:
del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2]
if stored_msg_ids:
if message.msg_id < stored_msg_ids[0]:
raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
if message.msg_id in stored_msg_ids:
raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
time_diff = (message.msg_id - MsgId()) / 2 ** 32
if time_diff > 30:
raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
"Most likely the client time has to be synchronized.")
if time_diff < -300:
raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
"Most likely the client time has to be synchronized.")
bisect.insort(stored_msg_ids, message.msg_id)
return message return message

View File

@ -17,6 +17,7 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
import bisect
import logging import logging
import os import os
from hashlib import sha1 from hashlib import sha1
@ -32,7 +33,7 @@ from pyrogram.errors import (
) )
from pyrogram.raw.all import layer from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
from .internals import MsgFactory from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -50,6 +51,7 @@ class Session:
MAX_RETRIES = 10 MAX_RETRIES = 10
ACKS_THRESHOLD = 10 ACKS_THRESHOLD = 10
PING_INTERVAL = 5 PING_INTERVAL = 5
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
def __init__( def __init__(
self, self,
@ -176,20 +178,14 @@ class Session:
await self.start() await self.start()
async def handle_packet(self, packet): async def handle_packet(self, packet):
try: data = await self.loop.run_in_executor(
data = await self.loop.run_in_executor( pyrogram.crypto_executor,
pyrogram.crypto_executor, mtproto.unpack,
mtproto.unpack, BytesIO(packet),
BytesIO(packet), self.session_id,
self.session_id, self.auth_key,
self.auth_key, self.auth_key_id
self.auth_key_id, )
self.stored_msg_ids
)
except SecurityCheckMismatch as e:
log.warning("Discarding packet: %s", e)
await self.connection.close()
return
messages = ( messages = (
data.body.messages data.body.messages
@ -206,6 +202,33 @@ class Session:
else: else:
self.pending_acks.add(msg.msg_id) self.pending_acks.add(msg.msg_id)
try:
if len(self.stored_msg_ids) > Session.STORED_MSG_IDS_MAX_SIZE:
del self.stored_msg_ids[:Session.STORED_MSG_IDS_MAX_SIZE // 2]
if self.stored_msg_ids:
if msg.msg_id < self.stored_msg_ids[0]:
raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
if msg.msg_id in self.stored_msg_ids:
raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
time_diff = (msg.msg_id - MsgId()) / 2 ** 32
if time_diff > 30:
raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
"Most likely the client time has to be synchronized.")
if time_diff < -300:
raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
"Most likely the client time has to be synchronized.")
except SecurityCheckMismatch as e:
log.warning("Discarding packet: %s", e)
await self.connection.close()
return
else:
bisect.insort(self.stored_msg_ids, msg.msg_id)
if isinstance(msg.body, (raw.types.MsgDetailedInfo, raw.types.MsgNewDetailedInfo)): if isinstance(msg.body, (raw.types.MsgDetailedInfo, raw.types.MsgNewDetailedInfo)):
self.pending_acks.add(msg.body.answer_msg_id) self.pending_acks.add(msg.body.answer_msg_id)
continue continue
@ -323,7 +346,7 @@ class Session:
RPCError.raise_it(result, type(data)) RPCError.raise_it(result, type(data))
if isinstance(result, raw.types.BadMsgNotification): if isinstance(result, raw.types.BadMsgNotification):
raise BadMsgNotification(result.error_code) log.warning("%s: %s", BadMsgNotification.__name__, BadMsgNotification(result.error_code))
if isinstance(result, raw.types.BadServerSalt): if isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt self.salt = result.new_server_salt