mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-27 16:45:19 +00:00
Apply security checks to each message in the container
This commit is contained in:
parent
7ee47b220d
commit
cf1e31c413
@ -16,18 +16,13 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import bisect
|
||||
from hashlib import sha256
|
||||
from io import BytesIO
|
||||
from os import urandom
|
||||
from typing import List
|
||||
|
||||
from pyrogram.errors import SecurityCheckMismatch
|
||||
from pyrogram.raw.core import Message, Long
|
||||
from . import aes
|
||||
from ..session.internals import MsgId
|
||||
|
||||
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
|
||||
|
||||
|
||||
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
|
||||
@ -59,8 +54,7 @@ def unpack(
|
||||
b: BytesIO,
|
||||
session_id: bytes,
|
||||
auth_key: bytes,
|
||||
auth_key_id: bytes,
|
||||
stored_msg_ids: List[int]
|
||||
auth_key_id: bytes
|
||||
) -> Message:
|
||||
SecurityCheckMismatch.check(b.read(8) == auth_key_id, "b.read(8) == auth_key_id")
|
||||
|
||||
@ -103,26 +97,4 @@ def unpack(
|
||||
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
|
||||
SecurityCheckMismatch.check(message.msg_id % 2 != 0, "message.msg_id % 2 != 0")
|
||||
|
||||
if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE:
|
||||
del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2]
|
||||
|
||||
if stored_msg_ids:
|
||||
if message.msg_id < stored_msg_ids[0]:
|
||||
raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
|
||||
|
||||
if message.msg_id in stored_msg_ids:
|
||||
raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
|
||||
|
||||
time_diff = (message.msg_id - MsgId()) / 2 ** 32
|
||||
|
||||
if time_diff > 30:
|
||||
raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
|
||||
"Most likely the client time has to be synchronized.")
|
||||
|
||||
if time_diff < -300:
|
||||
raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
|
||||
"Most likely the client time has to be synchronized.")
|
||||
|
||||
bisect.insort(stored_msg_ids, message.msg_id)
|
||||
|
||||
return message
|
||||
|
@ -17,6 +17,7 @@
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import bisect
|
||||
import logging
|
||||
import os
|
||||
from hashlib import sha1
|
||||
@ -32,7 +33,7 @@ from pyrogram.errors import (
|
||||
)
|
||||
from pyrogram.raw.all import layer
|
||||
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
|
||||
from .internals import MsgFactory
|
||||
from .internals import MsgId, MsgFactory
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -50,6 +51,7 @@ class Session:
|
||||
MAX_RETRIES = 10
|
||||
ACKS_THRESHOLD = 10
|
||||
PING_INTERVAL = 5
|
||||
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -176,20 +178,14 @@ class Session:
|
||||
await self.start()
|
||||
|
||||
async def handle_packet(self, packet):
|
||||
try:
|
||||
data = await self.loop.run_in_executor(
|
||||
pyrogram.crypto_executor,
|
||||
mtproto.unpack,
|
||||
BytesIO(packet),
|
||||
self.session_id,
|
||||
self.auth_key,
|
||||
self.auth_key_id,
|
||||
self.stored_msg_ids
|
||||
self.auth_key_id
|
||||
)
|
||||
except SecurityCheckMismatch as e:
|
||||
log.warning("Discarding packet: %s", e)
|
||||
await self.connection.close()
|
||||
return
|
||||
|
||||
messages = (
|
||||
data.body.messages
|
||||
@ -206,6 +202,33 @@ class Session:
|
||||
else:
|
||||
self.pending_acks.add(msg.msg_id)
|
||||
|
||||
try:
|
||||
if len(self.stored_msg_ids) > Session.STORED_MSG_IDS_MAX_SIZE:
|
||||
del self.stored_msg_ids[:Session.STORED_MSG_IDS_MAX_SIZE // 2]
|
||||
|
||||
if self.stored_msg_ids:
|
||||
if msg.msg_id < self.stored_msg_ids[0]:
|
||||
raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
|
||||
|
||||
if msg.msg_id in self.stored_msg_ids:
|
||||
raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
|
||||
|
||||
time_diff = (msg.msg_id - MsgId()) / 2 ** 32
|
||||
|
||||
if time_diff > 30:
|
||||
raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
|
||||
"Most likely the client time has to be synchronized.")
|
||||
|
||||
if time_diff < -300:
|
||||
raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
|
||||
"Most likely the client time has to be synchronized.")
|
||||
except SecurityCheckMismatch as e:
|
||||
log.warning("Discarding packet: %s", e)
|
||||
await self.connection.close()
|
||||
return
|
||||
else:
|
||||
bisect.insort(self.stored_msg_ids, msg.msg_id)
|
||||
|
||||
if isinstance(msg.body, (raw.types.MsgDetailedInfo, raw.types.MsgNewDetailedInfo)):
|
||||
self.pending_acks.add(msg.body.answer_msg_id)
|
||||
continue
|
||||
@ -323,7 +346,7 @@ class Session:
|
||||
RPCError.raise_it(result, type(data))
|
||||
|
||||
if isinstance(result, raw.types.BadMsgNotification):
|
||||
raise BadMsgNotification(result.error_code)
|
||||
log.warning("%s: %s", BadMsgNotification.__name__, BadMsgNotification(result.error_code))
|
||||
|
||||
if isinstance(result, raw.types.BadServerSalt):
|
||||
self.salt = result.new_server_salt
|
||||
|
Loading…
Reference in New Issue
Block a user