diff --git a/pyrogram/client.py b/pyrogram/client.py index b54794fa..8867d2f9 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -35,6 +35,7 @@ import pyrogram from pyrogram import raw from pyrogram import utils from pyrogram.crypto import aes +from pyrogram.errors import CDNFileHashMismatch from pyrogram.errors import ( SessionPasswordNeeded, VolumeLocNotFound, ChannelPrivate, @@ -1009,7 +1010,7 @@ class Client(Methods, Scaffold): # https://core.telegram.org/cdn#verifying-files for i, h in enumerate(hashes): cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] - assert h.hash == sha256(cdn_chunk).digest(), f"Invalid CDN hash part {i}" + CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest()) f.write(decrypted_chunk) diff --git a/pyrogram/crypto/mtproto.py b/pyrogram/crypto/mtproto.py index f9b885f6..ccea119c 100644 --- a/pyrogram/crypto/mtproto.py +++ b/pyrogram/crypto/mtproto.py @@ -22,6 +22,7 @@ from io import BytesIO from os import urandom from typing import List +from pyrogram.errors import SecurityCheckMismatch from pyrogram.raw.core import Message, Long from . import aes from ..session.internals import MsgId @@ -61,7 +62,7 @@ def unpack( auth_key_id: bytes, stored_msg_ids: List[int] ) -> Message: - assert b.read(8) == auth_key_id + SecurityCheckMismatch.check(b.read(8) == auth_key_id) msg_key = b.read(16) aes_key, aes_iv = kdf(auth_key, msg_key, False) @@ -69,7 +70,7 @@ def unpack( data.read(8) # Salt # https://core.telegram.org/mtproto/security_guidelines#checking-session-id - assert data.read(8) == session_id + SecurityCheckMismatch.check(data.read(8) == session_id) try: message = Message.read(data) @@ -87,17 +88,17 @@ def unpack( # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key # 96 = 88 + 8 (incoming message) - assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24] + SecurityCheckMismatch.check(msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]) # https://core.telegram.org/mtproto/security_guidelines#checking-message-length data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4) payload = data.read() padding = payload[message.length:] - assert 12 <= len(padding) <= 1024 - assert len(payload) % 4 == 0 + SecurityCheckMismatch.check(12 <= len(padding) <= 1024) + SecurityCheckMismatch.check(len(payload) % 4 == 0) # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id - assert message.msg_id % 2 != 0 + SecurityCheckMismatch.check(message.msg_id % 2 != 0) if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE: del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2] @@ -105,21 +106,21 @@ def unpack( if stored_msg_ids: # Ignored message: msg_id is lower than all of the stored values if message.msg_id < stored_msg_ids[0]: - assert False + SecurityCheckMismatch.check(False) # Ignored message: msg_id is equal to any of the stored values if message.msg_id in stored_msg_ids: - assert False + SecurityCheckMismatch.check(False) time_diff = (message.msg_id - MsgId()) / 2 ** 32 # Ignored message: msg_id belongs over 30 seconds in the future if time_diff > 30: - assert False + SecurityCheckMismatch.check(False) # Ignored message: msg_id belongs over 300 seconds in the past if time_diff < -300: - assert False + SecurityCheckMismatch.check(False) bisect.insort(stored_msg_ids, message.msg_id) diff --git a/pyrogram/errors/__init__.py b/pyrogram/errors/__init__.py index 1b94700f..5011b080 100644 --- a/pyrogram/errors/__init__.py +++ b/pyrogram/errors/__init__.py @@ -18,3 +18,27 @@ from .exceptions import * from .rpc_error import UnknownError + + +class SecurityError(Exception): + """Generic security error.""" + + @classmethod + def check(cls, cond: bool): + """Raises this exception if the condition is false""" + if not cond: + raise cls + + +class SecurityCheckMismatch(SecurityError): + """Raised when a security check mismatch occurs.""" + + def __init__(self): + super().__init__("A security check mismatch has occurred.") + + +class CDNFileHashMismatch(SecurityError): + """Raised when a CDN file hash mismatch occurs.""" + + def __init__(self): + super().__init__("A CDN file hash mismatch has occurred.") diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index a3e87ff4..6b1ad953 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -27,6 +27,7 @@ import pyrogram from pyrogram import raw from pyrogram.connection import Connection from pyrogram.crypto import aes, rsa, prime +from pyrogram.errors import SecurityCheckMismatch from pyrogram.raw.core import TLObject, Long, Int from .internals import MsgId @@ -210,33 +211,33 @@ class Auth: # Security checks ####################### - assert dh_prime == prime.CURRENT_DH_PRIME + SecurityCheckMismatch.check(dh_prime == prime.CURRENT_DH_PRIME) log.debug("DH parameters check: OK") # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation g_b = int.from_bytes(g_b, "big") - assert 1 < g < dh_prime - 1 - assert 1 < g_a < dh_prime - 1 - assert 1 < g_b < dh_prime - 1 - assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64) - assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64) + SecurityCheckMismatch.check(1 < g < dh_prime - 1) + SecurityCheckMismatch.check(1 < g_a < dh_prime - 1) + SecurityCheckMismatch.check(1 < g_b < dh_prime - 1) + SecurityCheckMismatch.check(2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64)) + SecurityCheckMismatch.check(2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64)) log.debug("g_a and g_b validation: OK") # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values answer = server_dh_inner_data.write() # Call .write() to remove padding - assert answer_with_hash[:20] == sha1(answer).digest() + SecurityCheckMismatch.check(answer_with_hash[:20] == sha1(answer).digest()) log.debug("SHA1 hash values check: OK") # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields # 1st message - assert nonce == res_pq.nonce + SecurityCheckMismatch.check(nonce == res_pq.nonce) # 2nd message server_nonce = int.from_bytes(server_nonce, "little", signed=True) - assert nonce == server_dh_params.nonce - assert server_nonce == server_dh_params.server_nonce + SecurityCheckMismatch.check(nonce == server_dh_params.nonce) + SecurityCheckMismatch.check(server_nonce == server_dh_params.server_nonce) # 3rd message - assert nonce == set_client_dh_params_answer.nonce - assert server_nonce == set_client_dh_params_answer.server_nonce + SecurityCheckMismatch.check(nonce == set_client_dh_params_answer.nonce) + SecurityCheckMismatch.check(server_nonce == set_client_dh_params_answer.server_nonce) server_nonce = server_nonce.to_bytes(16, "little", signed=True) log.debug("Nonce fields check: OK") diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 3504c3ad..72f27621 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -29,7 +29,10 @@ from pyrogram import __copyright__, __license__, __version__ from pyrogram import raw from pyrogram.connection import Connection from pyrogram.crypto import mtproto -from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, ServiceUnavailable +from pyrogram.errors import ( + RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, + ServiceUnavailable, SecurityCheckMismatch +) from pyrogram.raw.all import layer from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts from .internals import MsgId, MsgFactory @@ -230,7 +233,7 @@ class Session: self.auth_key_id, self.stored_msg_ids ) - except AssertionError: + except SecurityCheckMismatch: self.connection.close() return