Merge branch 'mtproto-checks'
# Conflicts: # pyrogram/errors/__init__.py # pyrogram/session/session.py
This commit is contained in:
commit
a3fab6af4b
@ -35,6 +35,7 @@ import pyrogram
|
|||||||
from pyrogram import raw
|
from pyrogram import raw
|
||||||
from pyrogram import utils
|
from pyrogram import utils
|
||||||
from pyrogram.crypto import aes
|
from pyrogram.crypto import aes
|
||||||
|
from pyrogram.errors import CDNFileHashMismatch
|
||||||
from pyrogram.errors import (
|
from pyrogram.errors import (
|
||||||
SessionPasswordNeeded,
|
SessionPasswordNeeded,
|
||||||
VolumeLocNotFound, ChannelPrivate,
|
VolumeLocNotFound, ChannelPrivate,
|
||||||
@ -1009,7 +1010,7 @@ class Client(Methods, Scaffold):
|
|||||||
# https://core.telegram.org/cdn#verifying-files
|
# https://core.telegram.org/cdn#verifying-files
|
||||||
for i, h in enumerate(hashes):
|
for i, h in enumerate(hashes):
|
||||||
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
cdn_chunk = decrypted_chunk[h.limit * i: h.limit * (i + 1)]
|
||||||
assert h.hash == sha256(cdn_chunk).digest(), f"Invalid CDN hash part {i}"
|
CDNFileHashMismatch.check(h.hash == sha256(cdn_chunk).digest())
|
||||||
|
|
||||||
f.write(decrypted_chunk)
|
f.write(decrypted_chunk)
|
||||||
|
|
||||||
|
@ -16,12 +16,18 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import bisect
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from os import urandom
|
from os import urandom
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from pyrogram.errors import SecurityCheckMismatch
|
||||||
from pyrogram.raw.core import Message, Long
|
from pyrogram.raw.core import Message, Long
|
||||||
from . import aes
|
from . import aes
|
||||||
|
from ..session.internals import MsgId
|
||||||
|
|
||||||
|
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
|
||||||
|
|
||||||
|
|
||||||
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
|
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
|
||||||
@ -49,16 +55,22 @@ def pack(message: Message, salt: int, session_id: bytes, auth_key: bytes, auth_k
|
|||||||
return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv)
|
return auth_key_id + msg_key + aes.ige256_encrypt(data + padding, aes_key, aes_iv)
|
||||||
|
|
||||||
|
|
||||||
def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -> Message:
|
def unpack(
|
||||||
assert b.read(8) == auth_key_id, b.getvalue()
|
b: BytesIO,
|
||||||
|
session_id: bytes,
|
||||||
|
auth_key: bytes,
|
||||||
|
auth_key_id: bytes,
|
||||||
|
stored_msg_ids: List[int]
|
||||||
|
) -> Message:
|
||||||
|
SecurityCheckMismatch.check(b.read(8) == auth_key_id)
|
||||||
|
|
||||||
msg_key = b.read(16)
|
msg_key = b.read(16)
|
||||||
aes_key, aes_iv = kdf(auth_key, msg_key, False)
|
aes_key, aes_iv = kdf(auth_key, msg_key, False)
|
||||||
data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv))
|
data = BytesIO(aes.ige256_decrypt(b.read(), aes_key, aes_iv))
|
||||||
data.read(8)
|
data.read(8) # Salt
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id
|
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id
|
||||||
assert data.read(8) == session_id
|
SecurityCheckMismatch.check(data.read(8) == session_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message = Message.read(data)
|
message = Message.read(data)
|
||||||
@ -75,11 +87,41 @@ def unpack(b: BytesIO, session_id: bytes, auth_key: bytes, auth_key_id: bytes) -
|
|||||||
raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}")
|
raise ValueError(f"The server sent an unknown constructor: {hex(e.args[0])}\n{left}")
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
|
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
|
||||||
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
|
|
||||||
# 96 = 88 + 8 (incoming message)
|
# 96 = 88 + 8 (incoming message)
|
||||||
assert msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]
|
SecurityCheckMismatch.check(msg_key == sha256(auth_key[96:96 + 32] + data.getvalue()).digest()[8:24])
|
||||||
|
|
||||||
|
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
|
||||||
|
data.seek(32) # Get to the payload, skip salt (8) + session_id (8) + msg_id (8) + seq_no (4) + length (4)
|
||||||
|
payload = data.read()
|
||||||
|
padding = payload[message.length:]
|
||||||
|
SecurityCheckMismatch.check(12 <= len(padding) <= 1024)
|
||||||
|
SecurityCheckMismatch.check(len(payload) % 4 == 0)
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
|
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
|
||||||
assert message.msg_id % 2 != 0
|
SecurityCheckMismatch.check(message.msg_id % 2 != 0)
|
||||||
|
|
||||||
|
if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE:
|
||||||
|
del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2]
|
||||||
|
|
||||||
|
if stored_msg_ids:
|
||||||
|
# Ignored message: msg_id is lower than all of the stored values
|
||||||
|
if message.msg_id < stored_msg_ids[0]:
|
||||||
|
raise SecurityCheckMismatch
|
||||||
|
|
||||||
|
# Ignored message: msg_id is equal to any of the stored values
|
||||||
|
if message.msg_id in stored_msg_ids:
|
||||||
|
raise SecurityCheckMismatch
|
||||||
|
|
||||||
|
time_diff = (message.msg_id - MsgId()) / 2 ** 32
|
||||||
|
|
||||||
|
# Ignored message: msg_id belongs over 30 seconds in the future
|
||||||
|
if time_diff > 30:
|
||||||
|
raise SecurityCheckMismatch
|
||||||
|
|
||||||
|
# Ignored message: msg_id belongs over 300 seconds in the past
|
||||||
|
if time_diff < -300:
|
||||||
|
raise SecurityCheckMismatch
|
||||||
|
|
||||||
|
bisect.insort(stored_msg_ids, message.msg_id)
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
@ -39,3 +39,27 @@ class BadMsgNotification(Exception):
|
|||||||
def __init__(self, code):
|
def __init__(self, code):
|
||||||
description = self.descriptions.get(code, "Unknown error code")
|
description = self.descriptions.get(code, "Unknown error code")
|
||||||
super().__init__(f"[{code}] {description}")
|
super().__init__(f"[{code}] {description}")
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityError(Exception):
|
||||||
|
"""Generic security error."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check(cls, cond: bool):
|
||||||
|
"""Raises this exception if the condition is false"""
|
||||||
|
if not cond:
|
||||||
|
raise cls
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityCheckMismatch(SecurityError):
|
||||||
|
"""Raised when a security check mismatch occurs."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("A security check mismatch has occurred.")
|
||||||
|
|
||||||
|
|
||||||
|
class CDNFileHashMismatch(SecurityError):
|
||||||
|
"""Raised when a CDN file hash mismatch occurs."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("A CDN file hash mismatch has occurred.")
|
||||||
|
@ -36,6 +36,4 @@ class AcceptTermsOfService(Scaffold):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert r
|
return bool(r)
|
||||||
|
|
||||||
return True
|
|
||||||
|
@ -42,3 +42,12 @@ class FutureSalt(TLObject):
|
|||||||
salt = Long.read(data)
|
salt = Long.read(data)
|
||||||
|
|
||||||
return FutureSalt(valid_since, valid_until, salt)
|
return FutureSalt(valid_since, valid_until, salt)
|
||||||
|
|
||||||
|
def write(self, *args: Any) -> bytes:
|
||||||
|
b = BytesIO()
|
||||||
|
|
||||||
|
b.write(Int(self.valid_since))
|
||||||
|
b.write(Int(self.valid_until))
|
||||||
|
b.write(Long(self.salt))
|
||||||
|
|
||||||
|
return b.getvalue()
|
||||||
|
@ -45,3 +45,17 @@ class FutureSalts(TLObject):
|
|||||||
salts = [FutureSalt.read(data) for _ in range(count)]
|
salts = [FutureSalt.read(data) for _ in range(count)]
|
||||||
|
|
||||||
return FutureSalts(req_msg_id, now, salts)
|
return FutureSalts(req_msg_id, now, salts)
|
||||||
|
|
||||||
|
def write(self, *args: Any) -> bytes:
|
||||||
|
b = BytesIO()
|
||||||
|
|
||||||
|
b.write(Long(self.req_msg_id))
|
||||||
|
b.write(Int(self.now))
|
||||||
|
|
||||||
|
count = len(self.salts)
|
||||||
|
b.write(Int(count))
|
||||||
|
|
||||||
|
for salt in self.salts:
|
||||||
|
b.write(salt.write())
|
||||||
|
|
||||||
|
return b.getvalue()
|
||||||
|
@ -27,6 +27,7 @@ import pyrogram
|
|||||||
from pyrogram import raw
|
from pyrogram import raw
|
||||||
from pyrogram.connection import Connection
|
from pyrogram.connection import Connection
|
||||||
from pyrogram.crypto import aes, rsa, prime
|
from pyrogram.crypto import aes, rsa, prime
|
||||||
|
from pyrogram.errors import SecurityCheckMismatch
|
||||||
from pyrogram.raw.core import TLObject, Long, Int
|
from pyrogram.raw.core import TLObject, Long, Int
|
||||||
from .internals import MsgId
|
from .internals import MsgId
|
||||||
|
|
||||||
@ -210,33 +211,33 @@ class Auth:
|
|||||||
# Security checks
|
# Security checks
|
||||||
#######################
|
#######################
|
||||||
|
|
||||||
assert dh_prime == prime.CURRENT_DH_PRIME
|
SecurityCheckMismatch.check(dh_prime == prime.CURRENT_DH_PRIME)
|
||||||
log.debug("DH parameters check: OK")
|
log.debug("DH parameters check: OK")
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
|
# https://core.telegram.org/mtproto/security_guidelines#g-a-and-g-b-validation
|
||||||
g_b = int.from_bytes(g_b, "big")
|
g_b = int.from_bytes(g_b, "big")
|
||||||
assert 1 < g < dh_prime - 1
|
SecurityCheckMismatch.check(1 < g < dh_prime - 1)
|
||||||
assert 1 < g_a < dh_prime - 1
|
SecurityCheckMismatch.check(1 < g_a < dh_prime - 1)
|
||||||
assert 1 < g_b < dh_prime - 1
|
SecurityCheckMismatch.check(1 < g_b < dh_prime - 1)
|
||||||
assert 2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64)
|
SecurityCheckMismatch.check(2 ** (2048 - 64) < g_a < dh_prime - 2 ** (2048 - 64))
|
||||||
assert 2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64)
|
SecurityCheckMismatch.check(2 ** (2048 - 64) < g_b < dh_prime - 2 ** (2048 - 64))
|
||||||
log.debug("g_a and g_b validation: OK")
|
log.debug("g_a and g_b validation: OK")
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
|
# https://core.telegram.org/mtproto/security_guidelines#checking-sha1-hash-values
|
||||||
answer = server_dh_inner_data.write() # Call .write() to remove padding
|
answer = server_dh_inner_data.write() # Call .write() to remove padding
|
||||||
assert answer_with_hash[:20] == sha1(answer).digest()
|
SecurityCheckMismatch.check(answer_with_hash[:20] == sha1(answer).digest())
|
||||||
log.debug("SHA1 hash values check: OK")
|
log.debug("SHA1 hash values check: OK")
|
||||||
|
|
||||||
# https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
|
# https://core.telegram.org/mtproto/security_guidelines#checking-nonce-server-nonce-and-new-nonce-fields
|
||||||
# 1st message
|
# 1st message
|
||||||
assert nonce == res_pq.nonce
|
SecurityCheckMismatch.check(nonce == res_pq.nonce)
|
||||||
# 2nd message
|
# 2nd message
|
||||||
server_nonce = int.from_bytes(server_nonce, "little", signed=True)
|
server_nonce = int.from_bytes(server_nonce, "little", signed=True)
|
||||||
assert nonce == server_dh_params.nonce
|
SecurityCheckMismatch.check(nonce == server_dh_params.nonce)
|
||||||
assert server_nonce == server_dh_params.server_nonce
|
SecurityCheckMismatch.check(server_nonce == server_dh_params.server_nonce)
|
||||||
# 3rd message
|
# 3rd message
|
||||||
assert nonce == set_client_dh_params_answer.nonce
|
SecurityCheckMismatch.check(nonce == set_client_dh_params_answer.nonce)
|
||||||
assert server_nonce == set_client_dh_params_answer.server_nonce
|
SecurityCheckMismatch.check(server_nonce == set_client_dh_params_answer.server_nonce)
|
||||||
server_nonce = server_nonce.to_bytes(16, "little", signed=True)
|
server_nonce = server_nonce.to_bytes(16, "little", signed=True)
|
||||||
log.debug("Nonce fields check: OK")
|
log.debug("Nonce fields check: OK")
|
||||||
|
|
||||||
|
@ -30,7 +30,8 @@ from pyrogram import raw
|
|||||||
from pyrogram.connection import Connection
|
from pyrogram.connection import Connection
|
||||||
from pyrogram.crypto import mtproto
|
from pyrogram.crypto import mtproto
|
||||||
from pyrogram.errors import (
|
from pyrogram.errors import (
|
||||||
RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, ServiceUnavailable, BadMsgNotification
|
RPCError, InternalServerError, AuthKeyDuplicated, FloodWait, ServiceUnavailable, BadMsgNotification,
|
||||||
|
SecurityCheckMismatch
|
||||||
)
|
)
|
||||||
from pyrogram.raw.all import layer
|
from pyrogram.raw.all import layer
|
||||||
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts
|
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalt, FutureSalts
|
||||||
@ -89,6 +90,8 @@ class Session:
|
|||||||
|
|
||||||
self.results = {}
|
self.results = {}
|
||||||
|
|
||||||
|
self.stored_msg_ids = []
|
||||||
|
|
||||||
self.ping_task = None
|
self.ping_task = None
|
||||||
self.ping_task_event = asyncio.Event()
|
self.ping_task_event = asyncio.Event()
|
||||||
|
|
||||||
@ -205,14 +208,19 @@ class Session:
|
|||||||
await self.start()
|
await self.start()
|
||||||
|
|
||||||
async def handle_packet(self, packet):
|
async def handle_packet(self, packet):
|
||||||
|
try:
|
||||||
data = await self.loop.run_in_executor(
|
data = await self.loop.run_in_executor(
|
||||||
pyrogram.crypto_executor,
|
pyrogram.crypto_executor,
|
||||||
mtproto.unpack,
|
mtproto.unpack,
|
||||||
BytesIO(packet),
|
BytesIO(packet),
|
||||||
self.session_id,
|
self.session_id,
|
||||||
self.auth_key,
|
self.auth_key,
|
||||||
self.auth_key_id
|
self.auth_key_id,
|
||||||
|
self.stored_msg_ids
|
||||||
)
|
)
|
||||||
|
except SecurityCheckMismatch:
|
||||||
|
self.connection.close()
|
||||||
|
return
|
||||||
|
|
||||||
messages = (
|
messages = (
|
||||||
data.body.messages
|
data.body.messages
|
||||||
|
Loading…
Reference in New Issue
Block a user