diff --git a/pyrogram/client.py b/pyrogram/client.py index 61722ac3..c5c72af8 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -35,6 +35,7 @@ import pyrogram from pyrogram import raw from pyrogram import utils from pyrogram.crypto import aes +from pyrogram.errors import CDNFileHashMismatch from pyrogram.errors import ( SessionPasswordNeeded, VolumeLocNotFound, ChannelPrivate, @@ -1009,7 +1010,7 @@ class Client(Methods, Scaffold): # https://core.telegram.org/cdn#verifying-files for i, h in enumerate(hashes): cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)] - assert h.hash == sha256(cdn_chunk).digest(), f"Invalid CDN hash part {i}" + CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest()) f.write(decrypted_chunk) diff --git a/pyrogram/crypto/mtproto.py b/pyrogram/crypto/mtproto.py index 803db297..2fc3b9f8 100644 --- a/pyrogram/crypto/mtproto.py +++ b/pyrogram/crypto/mtproto.py @@ -16,12 +16,18 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import bisect from hashlib import sha256 from io import BytesIO from os import urandom +from typing import List +from pyrogram.errors import SecurityCheckMismatch from pyrogram.raw.core import Message, Long from . import aes +from ..session.internals import MsgId + +STORED_MSG_IDS_MAX_SIZE = 1000 * 2 def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple: @@ -49,16 +55,22 @@ def pack(message: Message, salt: int, session_id: bytes, auth_key: bytes, auth_k return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv) -def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -> Message: - assert b.read(8) == auth_key_id, b.getvalue() +def unpack( + b: BytesIO, + session_id: bytes, + auth_key: bytes, + auth_key_id: bytes, + stored_msg_ids: List[int] +) -> Message: + SecurityCheckMismatch.check(b.read(8) == auth_key_id) msg_key = b.read(16) aes_key, aes_iv = kdf(auth_key, msg_key, False) data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv)) - data.read(8) + data.read(8) # Salt # https://core.telegram.org/mtproto/security_guidelines#checking-session-id - assert data.read(8) == session_id + SecurityCheckMismatch.check(data.read(8) == session_id) try: message = Message.read(data) @@ -75,11 +87,41 @@ def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) - raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}") # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key - # https://core.telegram.org/mtproto/security_guidelines#checking-message-length # 96 = 88 + 8 (incoming message) - assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24] + SecurityCheckMismatch.check(msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]) + + # https://core.telegram.org/mtproto/security_guidelines#checking-message-length + data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4) + payload = data.read() + padding = payload[message.length:] + SecurityCheckMismatch.check(12 <= len(padding) <= 1024) + SecurityCheckMismatch.check(len(payload) % 4 == 0) # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id - assert message.msg_id % 2 != 0 + SecurityCheckMismatch.check(message.msg_id % 2 != 0) + + if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE: + del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2] + + if stored_msg_ids: + # Ignored message: msg_id is lower than all of the stored values + if message.msg_id < stored_msg_ids[0]: + raise SecurityCheckMismatch + + # Ignored message: msg_id is equal to any of the stored values + if message.msg_id in stored_msg_ids: + raise SecurityCheckMismatch + + time_diff = (message.msg_id - MsgId()) / 2 ** 32 + + # Ignored message: msg_id belongs over 30 seconds in the future + if time_diff > 30: + raise SecurityCheckMismatch + + # Ignored message: msg_id belongs over 300 seconds in the past + if time_diff < -300: + raise SecurityCheckMismatch + + bisect.insort(stored_msg_ids, message.msg_id) return message diff --git a/pyrogram/errors/__init__.py b/pyrogram/errors/__init__.py index 514e7a12..c92f24b1 100644 --- a/pyrogram/errors/__init__.py +++ b/pyrogram/errors/__init__.py @@ -39,3 +39,27 @@ class BadMsgNotification(Exception): def __init__(self, code): description = self.descriptions.get(code, "Unknown error code") super().__init__(f"[{code}] {description}") + + +class SecurityError(Exception): + """Generic security error.""" + + @classmethod + def check(cls, cond: bool): + """Raises this exception if the condition is false""" + if not cond: + raise cls + + +class SecurityCheckMismatch(SecurityError): + """Raised when a security check mismatch occurs.""" + + def __init__(self): + super().__init__("A security check mismatch has occurred.") + + +class CDNFileHashMismatch(SecurityError): + """Raised when a CDN file hash mismatch occurs.""" + + def __init__(self): + super().__init__("A CDN file hash mismatch has occurred.") diff --git a/pyrogram/methods/auth/accept_terms_of_service.py b/pyrogram/methods/auth/accept_terms_of_service.py index b5abab86..c8cfd36d 100644 --- a/pyrogram/methods/auth/accept_terms_of_service.py +++ b/pyrogram/methods/auth/accept_terms_of_service.py @@ -36,6 +36,4 @@ class AcceptTermsOfService(Scaffold): ) ) - assert r - - return True + return bool(r) diff --git a/pyrogram/raw/core/future_salt.py b/pyrogram/raw/core/future_salt.py index 85303d12..54a12963 100644 --- a/pyrogram/raw/core/future_salt.py +++ b/pyrogram/raw/core/future_salt.py @@ -42,3 +42,12 @@ class FutureSalt(TLObject): salt = Long.read(data) return FutureSalt(valid_since, valid_until, salt) + + def write(self, *args: Any) -> bytes: + b = BytesIO() + + b.write(Int(self.valid_since)) + b.write(Int(self.valid_until)) + b.write(Long(self.salt)) + + return b.getvalue() diff --git a/pyrogram/raw/core/future_salts.py b/pyrogram/raw/core/future_salts.py index faa4b741..9fa2f8e9 100644 --- a/pyrogram/raw/core/future_salts.py +++ b/pyrogram/raw/core/future_salts.py @@ -45,3 +45,17 @@ class FutureSalts(TLObject): salts = [FutureSalt.read(data) for _ in range(count)] return FutureSalts(req_msg_id, now, salts) + + def write(self, *args: Any) -> bytes: + b = BytesIO() + + b.write(Long(self.req_msg_id)) + b.write(Int(self.now)) + + count = len(self.salts) + b.write(Int(count)) + + for salt in self.salts: + b.write(salt.write()) + + return b.getvalue() diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index a3e87ff4..6b1ad953 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -27,6 +27,7 @@ import pyrogram from pyrogram import raw from pyrogram.connection import Connection from pyrogram.crypto import aes, rsa, prime +from pyrogram.errors import SecurityCheckMismatch from pyrogram.raw.core import TLObject, Long, Int from .internals import MsgId @@ -210,33 +211,33 @@ class Auth: # Security checks ####################### - assert dh_prime == prime.CURRENT_DH_PRIME + SecurityCheckMismatch.check(dh_prime == prime.CURRENT_DH_PRIME) log.debug("DH parameters check: OK") # https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation g_b = int.from_bytes(g_b, "big") - assert 1 < g < dh_prime - 1 - assert 1 < g_a < dh_prime - 1 - assert 1 < g_b < dh_prime - 1 - assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64) - assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64) + SecurityCheckMismatch.check(1 < g < dh_prime - 1) + SecurityCheckMismatch.check(1 < g_a < dh_prime - 1) + SecurityCheckMismatch.check(1 < g_b < dh_prime - 1) + SecurityCheckMismatch.check(2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64)) + SecurityCheckMismatch.check(2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64)) log.debug("g_a and g_b validation: OK") # https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values answer = server_dh_inner_data.write() # Call .write() to remove padding - assert answer_with_hash[:20] == sha1(answer).digest() + SecurityCheckMismatch.check(answer_with_hash[:20] == sha1(answer).digest()) log.debug("SHA1 hash values check: OK") # https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields # 1st message - assert nonce == res_pq.nonce + SecurityCheckMismatch.check(nonce == res_pq.nonce) # 2nd message server_nonce = int.from_bytes(server_nonce, "little", signed=True) - assert nonce == server_dh_params.nonce - assert server_nonce == server_dh_params.server_nonce + SecurityCheckMismatch.check(nonce == server_dh_params.nonce) + SecurityCheckMismatch.check(server_nonce == server_dh_params.server_nonce) # 3rd message - assert nonce == set_client_dh_params_answer.nonce - assert server_nonce == set_client_dh_params_answer.server_nonce + SecurityCheckMismatch.check(nonce == set_client_dh_params_answer.nonce) + SecurityCheckMismatch.check(server_nonce == set_client_dh_params_answer.server_nonce) server_nonce = server_nonce.to_bytes(16, "little", signed=True) log.debug("Nonce fields check: OK") diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 06df2125..1cf8c1b1 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -30,7 +30,8 @@ from pyrogram import raw from pyrogram.connection import Connection from pyrogram.crypto import mtproto from pyrogram.errors import ( - RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, ServiceUnavailable, BadMsgNotification + RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, ServiceUnavailable, BadMsgNotification, + SecurityCheckMismatch ) from pyrogram.raw.all import layer from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts @@ -89,6 +90,8 @@ class Session: self.results = {} + self.stored_msg_ids = [] + self.ping_task = None self.ping_task_event = asyncio.Event() @@ -205,14 +208,19 @@ class Session: await self.start() async def handle_packet(self, packet): - data = await self.loop.run_in_executor( - pyrogram.crypto_executor, - mtproto.unpack, - BytesIO(packet), - self.session_id, - self.auth_key, - self.auth_key_id - ) + try: + data = await self.loop.run_in_executor( + pyrogram.crypto_executor, + mtproto.unpack, + BytesIO(packet), + self.session_id, + self.auth_key, + self.auth_key_id, + self.stored_msg_ids + ) + except SecurityCheckMismatch: + self.connection.close() + return messages = ( data.body.messages