Implement missing MTProto checks

This commit is contained in:
Dan 2021-12-15 13:18:13 +01:00
parent bf9e186414
commit cd027b8c1c
4 changed files with 71 additions and 4 deletions

View File

@ -19,9 +19,11 @@
from hashlib import sha256 from hashlib import sha256
from io import BytesIO from io import BytesIO
from os import urandom from os import urandom
from typing import Optional, List
from pyrogram.raw.core import Message, Long from pyrogram.raw.core import Message, Long
from . import aes from . import aes
from ..session.internals import MsgId
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple: def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
@ -49,13 +51,19 @@ def pack(message: Message, salt: int, session_id: bytes, auth_key: bytes, auth_k
return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv) return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv)
def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -> Message: def unpack(
b: BytesIO,
session_id: bytes,
auth_key: bytes,
auth_key_id: bytes,
stored_msg_ids: List[int]
) -> Optional[Message]:
assert b.read(8) == auth_key_id, b.getvalue() assert b.read(8) == auth_key_id, b.getvalue()
msg_key = b.read(16) msg_key = b.read(16)
aes_key, aes_iv = kdf(auth_key, msg_key, False) aes_key, aes_iv = kdf(auth_key, msg_key, False)
data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv)) data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv))
data.read(8) data.read(8) # Salt
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
assert data.read(8) == session_id assert data.read(8) == session_id
@ -75,11 +83,41 @@ def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -
raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}") raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}")
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
# 96 = 88 + 8 (incoming message) # 96 = 88 + 8 (incoming message)
assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24] assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4)
payload = data.read()
padding = payload[message.length:]
assert 12 <= len(padding) <= 1024
assert len(payload) % 4 == 0
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
assert message.msg_id % 2 != 0 assert message.msg_id % 2 != 0
if len(stored_msg_ids) > 200:
stored_msg_ids = stored_msg_ids[50:]
if stored_msg_ids:
# Ignored message: msg_id is lower than all of the stored values
if message.msg_id < stored_msg_ids[0]:
return None
# Ignored message: msg_id is equal to any of the stored values
if message.msg_id in stored_msg_ids:
return None
time_diff = (message.msg_id - MsgId()) / 2 ** 32
# Ignored message: msg_id belongs over 30 seconds in the future
if time_diff > 30:
return None
# Ignored message: msg_id belongs over 300 seconds in the past
if time_diff < -300:
return None
stored_msg_ids.append(message.msg_id)
return message return message

View File

@ -42,3 +42,12 @@ class FutureSalt(TLObject):
salt = Long.read(data) salt = Long.read(data)
return FutureSalt(valid_since, valid_until, salt) return FutureSalt(valid_since, valid_until, salt)
def write(self, *args: Any) -> bytes:
b = BytesIO()
b.write(Int(self.valid_since))
b.write(Int(self.valid_until))
b.write(Long(self.salt))
return b.getvalue()

View File

@ -45,3 +45,17 @@ class FutureSalts(TLObject):
salts = [FutureSalt.read(data) for _ in range(count)] salts = [FutureSalt.read(data) for _ in range(count)]
return FutureSalts(req_msg_id, now, salts) return FutureSalts(req_msg_id, now, salts)
def write(self, *args: Any) -> bytes:
b = BytesIO()
b.write(Long(self.req_msg_id))
b.write(Int(self.now))
count = len(self.salts)
b.write(Int(count))
for salt in self.salts:
b.write(salt.write())
return b.getvalue()

View File

@ -102,6 +102,8 @@ class Session:
self.results = {} self.results = {}
self.stored_msg_ids = []
self.ping_task = None self.ping_task = None
self.ping_task_event = asyncio.Event() self.ping_task_event = asyncio.Event()
@ -224,9 +226,13 @@ class Session:
BytesIO(packet), BytesIO(packet),
self.session_id, self.session_id,
self.auth_key, self.auth_key,
self.auth_key_id self.auth_key_id,
self.stored_msg_ids
) )
if data is None:
return
messages = ( messages = (
data.body.messages data.body.messages
if isinstance(data.body, MsgContainer) if isinstance(data.body, MsgContainer)