mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-28 00:56:19 +00:00
Apply security checks to each message in the container
This commit is contained in:
parent
7ee47b220d
commit
cf1e31c413
@ -16,18 +16,13 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import bisect
|
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from os import urandom
|
from os import urandom
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from pyrogram.errors import SecurityCheckMismatch
|
from pyrogram.errors import SecurityCheckMismatch
|
||||||
from pyrogram.raw.core import Message, Long
|
from pyrogram.raw.core import Message, Long
|
||||||
from . import aes
|
from . import aes
|
||||||
from ..session.internals import MsgId
|
|
||||||
|
|
||||||
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
|
|
||||||
|
|
||||||
|
|
||||||
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
|
def kdf(auth_key: bytes, msg_key: bytes, outgoing: bool) -> tuple:
|
||||||
@ -59,8 +54,7 @@ def unpack(
|
|||||||
b: BytesIO,
|
b: BytesIO,
|
||||||
session_id: bytes,
|
session_id: bytes,
|
||||||
auth_key: bytes,
|
auth_key: bytes,
|
||||||
auth_key_id: bytes,
|
auth_key_id: bytes
|
||||||
stored_msg_ids: List[int]
|
|
||||||
) -> Message:
|
) -> Message:
|
||||||
SecurityCheckMismatch.check(b.read(8) == auth_key_id, "b.read(8) == auth_key_id")
|
SecurityCheckMismatch.check(b.read(8) == auth_key_id, "b.read(8) == auth_key_id")
|
||||||
|
|
||||||
@ -103,26 +97,4 @@ def unpack(
|
|||||||
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
|
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
|
||||||
SecurityCheckMismatch.check(message.msg_id % 2 != 0, "message.msg_id % 2 != 0")
|
SecurityCheckMismatch.check(message.msg_id % 2 != 0, "message.msg_id % 2 != 0")
|
||||||
|
|
||||||
if len(stored_msg_ids) > STORED_MSG_IDS_MAX_SIZE:
|
|
||||||
del stored_msg_ids[:STORED_MSG_IDS_MAX_SIZE // 2]
|
|
||||||
|
|
||||||
if stored_msg_ids:
|
|
||||||
if message.msg_id < stored_msg_ids[0]:
|
|
||||||
raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
|
|
||||||
|
|
||||||
if message.msg_id in stored_msg_ids:
|
|
||||||
raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
|
|
||||||
|
|
||||||
time_diff = (message.msg_id - MsgId()) / 2 ** 32
|
|
||||||
|
|
||||||
if time_diff > 30:
|
|
||||||
raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
|
|
||||||
"Most likely the client time has to be synchronized.")
|
|
||||||
|
|
||||||
if time_diff < -300:
|
|
||||||
raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
|
|
||||||
"Most likely the client time has to be synchronized.")
|
|
||||||
|
|
||||||
bisect.insort(stored_msg_ids, message.msg_id)
|
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import bisect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
@ -32,7 +33,7 @@ from pyrogram.errors import (
|
|||||||
)
|
)
|
||||||
from pyrogram.raw.all import layer
|
from pyrogram.raw.all import layer
|
||||||
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
|
from pyrogram.raw.core import TLObject, MsgContainer, Int, FutureSalts
|
||||||
from .internals import MsgFactory
|
from .internals import MsgId, MsgFactory
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -50,6 +51,7 @@ class Session:
|
|||||||
MAX_RETRIES = 10
|
MAX_RETRIES = 10
|
||||||
ACKS_THRESHOLD = 10
|
ACKS_THRESHOLD = 10
|
||||||
PING_INTERVAL = 5
|
PING_INTERVAL = 5
|
||||||
|
STORED_MSG_IDS_MAX_SIZE = 1000 * 2
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -176,20 +178,14 @@ class Session:
|
|||||||
await self.start()
|
await self.start()
|
||||||
|
|
||||||
async def handle_packet(self, packet):
|
async def handle_packet(self, packet):
|
||||||
try:
|
data = await self.loop.run_in_executor(
|
||||||
data = await self.loop.run_in_executor(
|
pyrogram.crypto_executor,
|
||||||
pyrogram.crypto_executor,
|
mtproto.unpack,
|
||||||
mtproto.unpack,
|
BytesIO(packet),
|
||||||
BytesIO(packet),
|
self.session_id,
|
||||||
self.session_id,
|
self.auth_key,
|
||||||
self.auth_key,
|
self.auth_key_id
|
||||||
self.auth_key_id,
|
)
|
||||||
self.stored_msg_ids
|
|
||||||
)
|
|
||||||
except SecurityCheckMismatch as e:
|
|
||||||
log.warning("Discarding packet: %s", e)
|
|
||||||
await self.connection.close()
|
|
||||||
return
|
|
||||||
|
|
||||||
messages = (
|
messages = (
|
||||||
data.body.messages
|
data.body.messages
|
||||||
@ -206,6 +202,33 @@ class Session:
|
|||||||
else:
|
else:
|
||||||
self.pending_acks.add(msg.msg_id)
|
self.pending_acks.add(msg.msg_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if len(self.stored_msg_ids) > Session.STORED_MSG_IDS_MAX_SIZE:
|
||||||
|
del self.stored_msg_ids[:Session.STORED_MSG_IDS_MAX_SIZE // 2]
|
||||||
|
|
||||||
|
if self.stored_msg_ids:
|
||||||
|
if msg.msg_id < self.stored_msg_ids[0]:
|
||||||
|
raise SecurityCheckMismatch("The msg_id is lower than all the stored values")
|
||||||
|
|
||||||
|
if msg.msg_id in self.stored_msg_ids:
|
||||||
|
raise SecurityCheckMismatch("The msg_id is equal to any of the stored values")
|
||||||
|
|
||||||
|
time_diff = (msg.msg_id - MsgId()) / 2 ** 32
|
||||||
|
|
||||||
|
if time_diff > 30:
|
||||||
|
raise SecurityCheckMismatch("The msg_id belongs to over 30 seconds in the future. "
|
||||||
|
"Most likely the client time has to be synchronized.")
|
||||||
|
|
||||||
|
if time_diff < -300:
|
||||||
|
raise SecurityCheckMismatch("The msg_id belongs to over 300 seconds in the past. "
|
||||||
|
"Most likely the client time has to be synchronized.")
|
||||||
|
except SecurityCheckMismatch as e:
|
||||||
|
log.warning("Discarding packet: %s", e)
|
||||||
|
await self.connection.close()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
bisect.insort(self.stored_msg_ids, msg.msg_id)
|
||||||
|
|
||||||
if isinstance(msg.body, (raw.types.MsgDetailedInfo, raw.types.MsgNewDetailedInfo)):
|
if isinstance(msg.body, (raw.types.MsgDetailedInfo, raw.types.MsgNewDetailedInfo)):
|
||||||
self.pending_acks.add(msg.body.answer_msg_id)
|
self.pending_acks.add(msg.body.answer_msg_id)
|
||||||
continue
|
continue
|
||||||
@ -323,7 +346,7 @@ class Session:
|
|||||||
RPCError.raise_it(result, type(data))
|
RPCError.raise_it(result, type(data))
|
||||||
|
|
||||||
if isinstance(result, raw.types.BadMsgNotification):
|
if isinstance(result, raw.types.BadMsgNotification):
|
||||||
raise BadMsgNotification(result.error_code)
|
log.warning("%s: %s", BadMsgNotification.__name__, BadMsgNotification(result.error_code))
|
||||||
|
|
||||||
if isinstance(result, raw.types.BadServerSalt):
|
if isinstance(result, raw.types.BadServerSalt):
|
||||||
self.salt = result.new_server_salt
|
self.salt = result.new_server_salt
|
||||||
|
Loading…
Reference in New Issue
Block a user