Merge branch 'mtproto-checks'

# Conflicts:
#	pyrogram/errors/__init__.py
#	pyrogram/session/session.py
This commit is contained in:
Dan 2021-12-24 16:28:29 +01:00
commit a3fab6af4b
8 changed files with 129 additions and 32 deletions

View File

@ -35,6 +35,7 @@ import pyrogram
from pyrogram import raw from pyrogram import raw
from pyrogram import utils from pyrogram import utils
from pyrogram.crypto import aes from pyrogram.crypto import aes
from pyrogram.errors import CDNFileHashMismatch
from pyrogram.errors import ( from pyrogram.errors import (
SessionPasswordNeeded, SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate, VolumeLocNotFound, ChannelPrivate,
@ -1009,7 +1010,7 @@ class Client(Methods, Scaffold):
# https://core.telegram.org/cdn#verifying-files # https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes): for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), f"Invalid CDN hash part {i}" CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
f.write(decrypted_chunk) f.write(decrypted_chunk)

View File

@ -16,12 +16,18 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import bisect
from hashlib import sha256 from hashlib import sha256
from io import BytesIO from io import BytesIO
from os import urandom from os import urandom
from typing import List
from pyrogram.errors import SecurityCheckMismatch
from pyrogram.raw.core import Message, Long from pyrogram.raw.core import Message, Long
from . import aes from . import aes
from ..session.internals import MsgId
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple: def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
@ -49,16 +55,22 @@ def pack(message: Message, salt: int, session_id: bytes, auth_key: bytes, auth_k
return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv) return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv)
def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -> Message: def unpack(
assert b.read(8) == auth_key_id, b.getvalue() b: BytesIO,
session_id: bytes,
auth_key: bytes,
auth_key_id: bytes,
stored_msg_ids: List[int]
) -> Message:
SecurityCheckMismatch.check(b.read(8) == auth_key_id)
msg_key = b.read(16) msg_key = b.read(16)
aes_key, aes_iv = kdf(auth_key, msg_key, False) aes_key, aes_iv = kdf(auth_key, msg_key, False)
data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv)) data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv))
data.read(8) data.read(8) # Salt
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
assert data.read(8) == session_id SecurityCheckMismatch.check(data.read(8) == session_id)
try: try:
message = Message.read(data) message = Message.read(data)
@ -75,11 +87,41 @@ def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -
raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}") raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}")
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
# 96 = 88 + 8 (incoming message) # 96 = 88 + 8 (incoming message)
assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24] SecurityCheckMismatch.check(msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24])
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4)
payload = data.read()
padding = payload[message.length:]
SecurityCheckMismatch.check(12 <= len(padding) <= 1024)
SecurityCheckMismatch.check(len(payload) % 4 == 0)
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
assert message.msg_id % 2 != 0 SecurityCheckMismatch.check(message.msg_id % 2 != 0)
if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE:
del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2]
if stored_msg_ids:
# Ignored message: msg_id is lower than all of the stored values
if message.msg_id < stored_msg_ids[0]:
raise SecurityCheckMismatch
# Ignored message: msg_id is equal to any of the stored values
if message.msg_id in stored_msg_ids:
raise SecurityCheckMismatch
time_diff = (message.msg_id - MsgId()) / 2 ** 32
# Ignored message: msg_id belongs over 30 seconds in the future
if time_diff > 30:
raise SecurityCheckMismatch
# Ignored message: msg_id belongs over 300 seconds in the past
if time_diff < -300:
raise SecurityCheckMismatch
bisect.insort(stored_msg_ids, message.msg_id)
return message return message

View File

@ -39,3 +39,27 @@ class BadMsgNotification(Exception):
def __init__(self, code): def __init__(self, code):
description = self.descriptions.get(code, "Unknown error code") description = self.descriptions.get(code, "Unknown error code")
super().__init__(f"[{code}] {description}") super().__init__(f"[{code}] {description}")
class SecurityError(Exception):
"""Generic security error."""
@classmethod
def check(cls, cond: bool):
"""Raises this exception if the condition is false"""
if not cond:
raise cls
class SecurityCheckMismatch(SecurityError):
"""Raised when a security check mismatch occurs."""
def __init__(self):
super().__init__("A security check mismatch has occurred.")
class CDNFileHashMismatch(SecurityError):
"""Raised when a CDN file hash mismatch occurs."""
def __init__(self):
super().__init__("A CDN file hash mismatch has occurred.")

View File

@ -36,6 +36,4 @@ class AcceptTermsOfService(Scaffold):
) )
) )
assert r return bool(r)
return True

View File

@ -42,3 +42,12 @@ class FutureSalt(TLObject):
salt = Long.read(data) salt = Long.read(data)
return FutureSalt(valid_since, valid_until, salt) return FutureSalt(valid_since, valid_until, salt)
def write(self, *args: Any) -> bytes:
b = BytesIO()
b.write(Int(self.valid_since))
b.write(Int(self.valid_until))
b.write(Long(self.salt))
return b.getvalue()

View File

@ -45,3 +45,17 @@ class FutureSalts(TLObject):
salts = [FutureSalt.read(data) for _ in range(count)] salts = [FutureSalt.read(data) for _ in range(count)]
return FutureSalts(req_msg_id, now, salts) return FutureSalts(req_msg_id, now, salts)
def write(self, *args: Any) -> bytes:
b = BytesIO()
b.write(Long(self.req_msg_id))
b.write(Int(self.now))
count = len(self.salts)
b.write(Int(count))
for salt in self.salts:
b.write(salt.write())
return b.getvalue()

View File

@ -27,6 +27,7 @@ import pyrogram
from pyrogram import raw from pyrogram import raw
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import aes, rsa, prime from pyrogram.crypto import aes, rsa, prime
from pyrogram.errors import SecurityCheckMismatch
from pyrogram.raw.core import TLObject, Long, Int from pyrogram.raw.core import TLObject, Long, Int
from .internals import MsgId from .internals import MsgId
@ -210,33 +211,33 @@ class Auth:
# Security checks # Security checks
####################### #######################
assert dh_prime == prime.CURRENT_DH_PRIME SecurityCheckMismatch.check(dh_prime == prime.CURRENT_DH_PRIME)
log.debug("DH parameters check: OK") log.debug("DH parameters check: OK")
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
g_b = int.from_bytes(g_b, "big") g_b = int.from_bytes(g_b, "big")
assert 1 < g < dh_prime - 1 SecurityCheckMismatch.check(1 < g < dh_prime - 1)
assert 1 < g_a < dh_prime - 1 SecurityCheckMismatch.check(1 < g_a < dh_prime - 1)
assert 1 < g_b < dh_prime - 1 SecurityCheckMismatch.check(1 < g_b < dh_prime - 1)
assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64) SecurityCheckMismatch.check(2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64))
assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64) SecurityCheckMismatch.check(2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64))
log.debug("g_a and g_b validation: OK") log.debug("g_a and g_b validation: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
answer = server_dh_inner_data.write() # Call .write() to remove padding answer = server_dh_inner_data.write() # Call .write() to remove padding
assert answer_with_hash[:20] == sha1(answer).digest() SecurityCheckMismatch.check(answer_with_hash[:20] == sha1(answer).digest())
log.debug("SHA1 hash values check: OK") log.debug("SHA1 hash values check: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
# 1st message # 1st message
assert nonce == res_pq.nonce SecurityCheckMismatch.check(nonce == res_pq.nonce)
# 2nd message # 2nd message
server_nonce = int.from_bytes(server_nonce, "little", signed=True) server_nonce = int.from_bytes(server_nonce, "little", signed=True)
assert nonce == server_dh_params.nonce SecurityCheckMismatch.check(nonce == server_dh_params.nonce)
assert server_nonce == server_dh_params.server_nonce SecurityCheckMismatch.check(server_nonce == server_dh_params.server_nonce)
# 3rd message # 3rd message
assert nonce == set_client_dh_params_answer.nonce SecurityCheckMismatch.check(nonce == set_client_dh_params_answer.nonce)
assert server_nonce == set_client_dh_params_answer.server_nonce SecurityCheckMismatch.check(server_nonce == set_client_dh_params_answer.server_nonce)
server_nonce = server_nonce.to_bytes(16, "little", signed=True) server_nonce = server_nonce.to_bytes(16, "little", signed=True)
log.debug("Nonce fields check: OK") log.debug("Nonce fields check: OK")

View File

@ -30,7 +30,8 @@ from pyrogram import raw
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import mtproto from pyrogram.crypto import mtproto
from pyrogram.errors import ( from pyrogram.errors import (
RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, ServiceUnavailable, BadMsgNotification RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, ServiceUnavailable, BadMsgNotification,
SecurityCheckMismatch
) )
from pyrogram.raw.all import layer from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts
@ -89,6 +90,8 @@ class Session:
self.results = {} self.results = {}
self.stored_msg_ids = []
self.ping_task = None self.ping_task = None
self.ping_task_event = asyncio.Event() self.ping_task_event = asyncio.Event()
@ -205,14 +208,19 @@ class Session:
await self.start() await self.start()
async def handle_packet(self, packet): async def handle_packet(self, packet):
data = await self.loop.run_in_executor( try:
pyrogram.crypto_executor, data = await self.loop.run_in_executor(
mtproto.unpack, pyrogram.crypto_executor,
BytesIO(packet), mtproto.unpack,
self.session_id, BytesIO(packet),
self.auth_key, self.session_id,
self.auth_key_id self.auth_key,
) self.auth_key_id,
self.stored_msg_ids
)
except SecurityCheckMismatch:
self.connection.close()
return
messages = ( messages = (
data.body.messages data.body.messages