Use specialized exceptions for handling security checks

This commit is contained in:
Dan 2021-12-16 21:38:24 +01:00
parent a720726479
commit 8aa358129c
5 changed files with 55 additions and 25 deletions

View File

@ -35,6 +35,7 @@ import pyrogram
from pyrogram import raw
from pyrogram import utils
from pyrogram.crypto import aes
from pyrogram.errors import CDNFileHashMismatch
from pyrogram.errors import (
SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate,
@ -1009,7 +1010,7 @@ class Client(Methods, Scaffold):
# https://core.telegram.org/cdn#verifying-files
for i, h in enumerate(hashes):
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
assert h.hash == sha256(cdn_chunk).digest(), f"Invalid CDN hash part {i}"
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
f.write(decrypted_chunk)

View File

@ -22,6 +22,7 @@ from io import BytesIO
from os import urandom
from typing import List
from pyrogram.errors import SecurityCheckMismatch
from pyrogram.raw.core import Message, Long
from . import aes
from ..session.internals import MsgId
@ -61,7 +62,7 @@ def unpack(
auth_key_id: bytes,
stored_msg_ids: List[int]
) -> Message:
assert b.read(8) == auth_key_id
SecurityCheckMismatch.check(b.read(8) == auth_key_id)
msg_key = b.read(16)
aes_key, aes_iv = kdf(auth_key, msg_key, False)
@ -69,7 +70,7 @@ def unpack(
data.read(8) # Salt
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id
assert data.read(8) == session_id
SecurityCheckMismatch.check(data.read(8) == session_id)
try:
message = Message.read(data)
@ -87,17 +88,17 @@ def unpack(
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
# 96 = 88 + 8 (incoming message)
assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]
SecurityCheckMismatch.check(msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24])
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4)
payload = data.read()
padding = payload[message.length:]
assert 12 <= len(padding) <= 1024
assert len(payload) % 4 == 0
SecurityCheckMismatch.check(12 <= len(padding) <= 1024)
SecurityCheckMismatch.check(len(payload) % 4 == 0)
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
assert message.msg_id % 2 != 0
SecurityCheckMismatch.check(message.msg_id % 2 != 0)
if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE:
del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2]
@ -105,21 +106,21 @@ def unpack(
if stored_msg_ids:
# Ignored message: msg_id is lower than all of the stored values
if message.msg_id < stored_msg_ids[0]:
assert False
SecurityCheckMismatch.check(False)
# Ignored message: msg_id is equal to any of the stored values
if message.msg_id in stored_msg_ids:
assert False
SecurityCheckMismatch.check(False)
time_diff = (message.msg_id - MsgId()) / 2 ** 32
# Ignored message: msg_id belongs over 30 seconds in the future
if time_diff > 30:
assert False
SecurityCheckMismatch.check(False)
# Ignored message: msg_id belongs over 300 seconds in the past
if time_diff < -300:
assert False
SecurityCheckMismatch.check(False)
bisect.insort(stored_msg_ids, message.msg_id)

View File

@ -18,3 +18,27 @@
from .exceptions import *
from .rpc_error import UnknownError
class SecurityError(Exception):
"""Generic security error."""
@classmethod
def check(cls, cond: bool):
"""Raises this exception if the condition is false"""
if not cond:
raise cls
class SecurityCheckMismatch(SecurityError):
"""Raised when a security check mismatch occurs."""
def __init__(self):
super().__init__("A security check mismatch has occurred.")
class CDNFileHashMismatch(SecurityError):
"""Raised when a CDN file hash mismatch occurs."""
def __init__(self):
super().__init__("A CDN file hash mismatch has occurred.")

View File

@ -27,6 +27,7 @@ import pyrogram
from pyrogram import raw
from pyrogram.connection import Connection
from pyrogram.crypto import aes, rsa, prime
from pyrogram.errors import SecurityCheckMismatch
from pyrogram.raw.core import TLObject, Long, Int
from .internals import MsgId
@ -210,33 +211,33 @@ class Auth:
# Security checks
#######################
assert dh_prime == prime.CURRENT_DH_PRIME
SecurityCheckMismatch.check(dh_prime == prime.CURRENT_DH_PRIME)
log.debug("DH parameters check: OK")
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
g_b = int.from_bytes(g_b, "big")
assert 1 < g < dh_prime - 1
assert 1 < g_a < dh_prime - 1
assert 1 < g_b < dh_prime - 1
assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64)
assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64)
SecurityCheckMismatch.check(1 < g < dh_prime - 1)
SecurityCheckMismatch.check(1 < g_a < dh_prime - 1)
SecurityCheckMismatch.check(1 < g_b < dh_prime - 1)
SecurityCheckMismatch.check(2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64))
SecurityCheckMismatch.check(2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64))
log.debug("g_a and g_b validation: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
answer = server_dh_inner_data.write() # Call .write() to remove padding
assert answer_with_hash[:20] == sha1(answer).digest()
SecurityCheckMismatch.check(answer_with_hash[:20] == sha1(answer).digest())
log.debug("SHA1 hash values check: OK")
# https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
# 1st message
assert nonce == res_pq.nonce
SecurityCheckMismatch.check(nonce == res_pq.nonce)
# 2nd message
server_nonce = int.from_bytes(server_nonce, "little", signed=True)
assert nonce == server_dh_params.nonce
assert server_nonce == server_dh_params.server_nonce
SecurityCheckMismatch.check(nonce == server_dh_params.nonce)
SecurityCheckMismatch.check(server_nonce == server_dh_params.server_nonce)
# 3rd message
assert nonce == set_client_dh_params_answer.nonce
assert server_nonce == set_client_dh_params_answer.server_nonce
SecurityCheckMismatch.check(nonce == set_client_dh_params_answer.nonce)
SecurityCheckMismatch.check(server_nonce == set_client_dh_params_answer.server_nonce)
server_nonce = server_nonce.to_bytes(16, "little", signed=True)
log.debug("Nonce fields check: OK")

View File

@ -29,7 +29,10 @@ from pyrogram import __copyright__, __license__, __version__
from pyrogram import raw
from pyrogram.connection import Connection
from pyrogram.crypto import mtproto
from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, ServiceUnavailable
from pyrogram.errors import (
RPCError, InternalServerError, AuthKeyDuplicated, FloodWait,
ServiceUnavailable, SecurityCheckMismatch
)
from pyrogram.raw.all import layer
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts
from .internals import MsgId, MsgFactory
@ -230,7 +233,7 @@ class Session:
self.auth_key_id,
self.stored_msg_ids
)
except AssertionError:
except SecurityCheckMismatch:
self.connection.close()
return