mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-23 23:34:28 +00:00
Implement missing MTProto checks
This commit is contained in:
parent
bf9e186414
commit
cd027b8c1c
@ -19,9 +19,11 @@
|
||||
from hashlib import sha256
|
||||
from io import BytesIO
|
||||
from os import urandom
|
||||
from typing import Optional, List
|
||||
|
||||
from pyrogram.raw.core import Message, Long
|
||||
from . import aes
|
||||
from ..session.internals import MsgId
|
||||
|
||||
|
||||
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
|
||||
@ -49,13 +51,19 @@ def pack(message: Message, salt: int, session_id: bytes, auth_key: bytes, auth_k
|
||||
return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv)
|
||||
|
||||
|
||||
def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -> Message:
|
||||
def unpack(
|
||||
b: BytesIO,
|
||||
session_id: bytes,
|
||||
auth_key: bytes,
|
||||
auth_key_id: bytes,
|
||||
stored_msg_ids: List[int]
|
||||
) -> Optional[Message]:
|
||||
assert b.read(8) == auth_key_id, b.getvalue()
|
||||
|
||||
msg_key = b.read(16)
|
||||
aes_key, aes_iv = kdf(auth_key, msg_key, False)
|
||||
data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv))
|
||||
data.read(8)
|
||||
data.read(8) # Salt
|
||||
|
||||
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id
|
||||
assert data.read(8) == session_id
|
||||
@ -75,11 +83,41 @@ def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -
|
||||
raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}")
|
||||
|
||||
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
|
||||
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
|
||||
# 96 = 88 + 8 (incoming message)
|
||||
assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]
|
||||
|
||||
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
|
||||
data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4)
|
||||
payload = data.read()
|
||||
padding = payload[message.length:]
|
||||
assert 12 <= len(padding) <= 1024
|
||||
assert len(payload) % 4 == 0
|
||||
|
||||
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
|
||||
assert message.msg_id % 2 != 0
|
||||
|
||||
if len(stored_msg_ids) > 200:
|
||||
stored_msg_ids = stored_msg_ids[50:]
|
||||
|
||||
if stored_msg_ids:
|
||||
# Ignored message: msg_id is lower than all of the stored values
|
||||
if message.msg_id < stored_msg_ids[0]:
|
||||
return None
|
||||
|
||||
# Ignored message: msg_id is equal to any of the stored values
|
||||
if message.msg_id in stored_msg_ids:
|
||||
return None
|
||||
|
||||
time_diff = (message.msg_id - MsgId()) / 2 ** 32
|
||||
|
||||
# Ignored message: msg_id belongs over 30 seconds in the future
|
||||
if time_diff > 30:
|
||||
return None
|
||||
|
||||
# Ignored message: msg_id belongs over 300 seconds in the past
|
||||
if time_diff < -300:
|
||||
return None
|
||||
|
||||
stored_msg_ids.append(message.msg_id)
|
||||
|
||||
return message
|
||||
|
@ -42,3 +42,12 @@ class FutureSalt(TLObject):
|
||||
salt = Long.read(data)
|
||||
|
||||
return FutureSalt(valid_since, valid_until, salt)
|
||||
|
||||
def write(self, *args: Any) -> bytes:
|
||||
b = BytesIO()
|
||||
|
||||
b.write(Int(self.valid_since))
|
||||
b.write(Int(self.valid_until))
|
||||
b.write(Long(self.salt))
|
||||
|
||||
return b.getvalue()
|
||||
|
@ -45,3 +45,17 @@ class FutureSalts(TLObject):
|
||||
salts = [FutureSalt.read(data) for _ in range(count)]
|
||||
|
||||
return FutureSalts(req_msg_id, now, salts)
|
||||
|
||||
def write(self, *args: Any) -> bytes:
|
||||
b = BytesIO()
|
||||
|
||||
b.write(Long(self.req_msg_id))
|
||||
b.write(Int(self.now))
|
||||
|
||||
count = len(self.salts)
|
||||
b.write(Int(count))
|
||||
|
||||
for salt in self.salts:
|
||||
b.write(salt.write())
|
||||
|
||||
return b.getvalue()
|
||||
|
@ -102,6 +102,8 @@ class Session:
|
||||
|
||||
self.results = {}
|
||||
|
||||
self.stored_msg_ids = []
|
||||
|
||||
self.ping_task = None
|
||||
self.ping_task_event = asyncio.Event()
|
||||
|
||||
@ -224,9 +226,13 @@ class Session:
|
||||
BytesIO(packet),
|
||||
self.session_id,
|
||||
self.auth_key,
|
||||
self.auth_key_id
|
||||
self.auth_key_id,
|
||||
self.stored_msg_ids
|
||||
)
|
||||
|
||||
if data is None:
|
||||
return
|
||||
|
||||
messages = (
|
||||
data.body.messages
|
||||
if isinstance(data.body, MsgContainer)
|
||||
|
Loading…
Reference in New Issue
Block a user