Use specialized exceptions for handling security checks

This commit is contained in:
Dan 2021-12-16 21:38:24 +01:00
parent a720726479
commit 8aa358129c
5 changed files with 55 additions and 25 deletions

View File

@ -35,6 +35,7 @@ import pyrogram
from pyrogram import raw from pyrogram import raw
from pyrogram import utils from pyrogram import utils
from pyrogram.crypto import aes from pyrogram.crypto import aes
from pyrogram.errors import CDNFileHashMismatch
from pyrogram.errors import ( from pyrogram.errors import (
SessionPasswordNeeded, SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate, VolumeLocNotFound, ChannelPrivate,
@ -1009,7 +1010,7 @@ class Client(Methods, Scaffold):
# https://core.telegram.org/cdn#verifying-files # https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes): for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), f"Invalid CDN hash part {i}" CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
f.write(decrypted_chunk) f.write(decrypted_chunk)

View File

@ -22,6 +22,7 @@ from io import BytesIO
from os import urandom from os import urandom
from typing import List from typing import List
from pyrogram.errors import SecurityCheckMismatch
from pyrogram.raw.core import Message, Long from pyrogram.raw.core import Message, Long
from . import aes from . import aes
from ..session.internals import MsgId from ..session.internals import MsgId
@ -61,7 +62,7 @@ def unpack(
auth_key_id: bytes, auth_key_id: bytes,
stored_msg_ids: List[int] stored_msg_ids: List[int]
) -> Message: ) -> Message:
assert b.read(8) == auth_key_id SecurityCheckMismatch.check(b.read(8) == auth_key_id)
msg_key = b.read(16) msg_key = b.read(16)
aes_key, aes_iv = kdf(auth_key, msg_key, False) aes_key, aes_iv = kdf(auth_key, msg_key, False)
@ -69,7 +70,7 @@ def unpack(
data.read(8) # Salt data.read(8) # Salt
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
assert data.read(8) == session_id SecurityCheckMismatch.check(data.read(8) == session_id)
try: try:
message = Message.read(data) message = Message.read(data)
@ -87,17 +88,17 @@ def unpack(
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
# 96 = 88 + 8 (incoming message) # 96 = 88 + 8 (incoming message)
assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24] SecurityCheckMismatch.check(msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24])
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length # https://core.telegram.org/mtproto/security_guidelines#checking-message-length
data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4) data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4)
payload = data.read() payload = data.read()
padding = payload[message.length:] padding = payload[message.length:]
assert 12 <= len(padding) <= 1024 SecurityCheckMismatch.check(12 <= len(padding) <= 1024)
assert len(payload) % 4 == 0 SecurityCheckMismatch.check(len(payload) % 4 == 0)
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
assert message.msg_id % 2 != 0 SecurityCheckMismatch.check(message.msg_id % 2 != 0)
if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE: if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE:
del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2] del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2]
@ -105,21 +106,21 @@ def unpack(
if stored_msg_ids: if stored_msg_ids:
# Ignored message: msg_id is lower than all of the stored values # Ignored message: msg_id is lower than all of the stored values
if message.msg_id < stored_msg_ids[0]: if message.msg_id < stored_msg_ids[0]:
assert False SecurityCheckMismatch.check(False)
# Ignored message: msg_id is equal to any of the stored values # Ignored message: msg_id is equal to any of the stored values
if message.msg_id in stored_msg_ids: if message.msg_id in stored_msg_ids:
assert False SecurityCheckMismatch.check(False)
time_diff = (message.msg_id - MsgId()) / 2 ** 32 time_diff = (message.msg_id - MsgId()) / 2 ** 32
# Ignored message: msg_id belongs over 30 seconds in the future # Ignored message: msg_id belongs over 30 seconds in the future
if time_diff > 30: if time_diff > 30:
assert False SecurityCheckMismatch.check(False)
# Ignored message: msg_id belongs over 300 seconds in the past # Ignored message: msg_id belongs over 300 seconds in the past
if time_diff < -300: if time_diff < -300:
assert False SecurityCheckMismatch.check(False)
bisect.insort(stored_msg_ids, message.msg_id) bisect.insort(stored_msg_ids, message.msg_id)

View File

@ -18,3 +18,27 @@
from .exceptions import * from .exceptions import *
from .rpc_error import UnknownError from .rpc_error import UnknownError
class SecurityError(Exception):
"""Generic security error."""
@classmethod
def check(cls, cond: bool):
"""Raises this exception if the condition is false"""
if not cond:
raise cls
class SecurityCheckMismatch(SecurityError):
"""Raised when a security check mismatch occurs."""
def __init__(self):
super().__init__("A security check mismatch has occurred.")
class CDNFileHashMismatch(SecurityError):
"""Raised when a CDN file hash mismatch occurs."""
def __init__(self):
super().__init__("A CDN file hash mismatch has occurred.")

View File

@ -27,6 +27,7 @@ import pyrogram
from pyrogram import raw from pyrogram import raw
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import aes, rsa, prime from pyrogram.crypto import aes, rsa, prime
from pyrogram.errors import SecurityCheckMismatch
from pyrogram.raw.core import TLObject, Long, Int from pyrogram.raw.core import TLObject, Long, Int
from .internals import MsgId from .internals import MsgId
@ -210,33 +211,33 @@ class Auth:
# Security checks # Security checks
####################### #######################
assert dh_prime == prime.CURRENT_DH_PRIME SecurityCheckMismatch.check(dh_prime == prime.CURRENT_DH_PRIME)
log.debug("DH parameters check: OK") log.debug("DH parameters check: OK")
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
g_b = int.from_bytes(g_b, "big") g_b = int.from_bytes(g_b, "big")
assert 1 < g < dh_prime - 1 SecurityCheckMismatch.check(1 < g < dh_prime - 1)
assert 1 < g_a < dh_prime - 1 SecurityCheckMismatch.check(1 < g_a < dh_prime - 1)
assert 1 < g_b < dh_prime - 1 SecurityCheckMismatch.check(1 < g_b < dh_prime - 1)
assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64) SecurityCheckMismatch.check(2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64))
assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64) SecurityCheckMismatch.check(2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64))
log.debug("g_a and g_b validation: OK") log.debug("g_a and g_b validation: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
answer = server_dh_inner_data.write() # Call .write() to remove padding answer = server_dh_inner_data.write() # Call .write() to remove padding
assert answer_with_hash[:20] == sha1(answer).digest() SecurityCheckMismatch.check(answer_with_hash[:20] == sha1(answer).digest())
log.debug("SHA1 hash values check: OK") log.debug("SHA1 hash values check: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
# 1st message # 1st message
assert nonce == res_pq.nonce SecurityCheckMismatch.check(nonce == res_pq.nonce)
# 2nd message # 2nd message
server_nonce = int.from_bytes(server_nonce, "little", signed=True) server_nonce = int.from_bytes(server_nonce, "little", signed=True)
assert nonce == server_dh_params.nonce SecurityCheckMismatch.check(nonce == server_dh_params.nonce)
assert server_nonce == server_dh_params.server_nonce SecurityCheckMismatch.check(server_nonce == server_dh_params.server_nonce)
# 3rd message # 3rd message
assert nonce == set_client_dh_params_answer.nonce SecurityCheckMismatch.check(nonce == set_client_dh_params_answer.nonce)
assert server_nonce == set_client_dh_params_answer.server_nonce SecurityCheckMismatch.check(server_nonce == set_client_dh_params_answer.server_nonce)
server_nonce = server_nonce.to_bytes(16, "little", signed=True) server_nonce = server_nonce.to_bytes(16, "little", signed=True)
log.debug("Nonce fields check: OK") log.debug("Nonce fields check: OK")

View File

@ -29,7 +29,10 @@ from pyrogram import __copyright__, __license__, __version__
from pyrogram import raw from pyrogram import raw
from pyrogram.connection import Connection from pyrogram.connection import Connection
from pyrogram.crypto import mtproto from pyrogram.crypto import mtproto
from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, ServiceUnavailable from pyrogram.errors import (
RPCError, InternalServerError, AuthKeyDuplicated, FloodWait,
ServiceUnavailable, SecurityCheckMismatch
)
from pyrogram.raw.all import layer from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts
from .internals import MsgId, MsgFactory from .internals import MsgId, MsgFactory
@ -230,7 +233,7 @@ class Session:
self.auth_key_id, self.auth_key_id,
self.stored_msg_ids self.stored_msg_ids
) )
except AssertionError: except SecurityCheckMismatch:
self.connection.close() self.connection.close()
return return