Close and reestablish the TCP connection in case of mismatch

This commit is contained in:
Dan 2021-12-15 16:02:39 +01:00
parent bc420da0e2
commit 2a1af2b8e9
2 changed files with 22 additions and 17 deletions

View File

@ -20,7 +20,7 @@ import bisect
from hashlib import sha256 from hashlib import sha256
from io import BytesIO from io import BytesIO
from os import urandom from os import urandom
from typing import Optional, List from typing import List, Tuple
from pyrogram.raw.core import Message, Long from pyrogram.raw.core import Message, Long
from . import aes from . import aes
@ -58,7 +58,7 @@ def unpack(
auth_key: bytes, auth_key: bytes,
auth_key_id: bytes, auth_key_id: bytes,
stored_msg_ids: List[int] stored_msg_ids: List[int]
) -> Optional[Message]: ) -> Tuple[Message, bool]:
assert b.read(8) == auth_key_id, b.getvalue() assert b.read(8) == auth_key_id, b.getvalue()
msg_key = b.read(16) msg_key = b.read(16)
@ -103,22 +103,22 @@ def unpack(
if stored_msg_ids: if stored_msg_ids:
# Ignored message: msg_id is lower than all of the stored values # Ignored message: msg_id is lower than all of the stored values
if message.msg_id < stored_msg_ids[0]: if message.msg_id < stored_msg_ids[0]:
return None return message, False
# Ignored message: msg_id is equal to any of the stored values # Ignored message: msg_id is equal to any of the stored values
if message.msg_id in stored_msg_ids: if message.msg_id in stored_msg_ids:
return None return message, False
time_diff = (message.msg_id - MsgId()) / 2 ** 32 time_diff = (message.msg_id - MsgId()) / 2 ** 32
# Ignored message: msg_id belongs over 30 seconds in the future # Ignored message: msg_id belongs over 30 seconds in the future
if time_diff > 30: if time_diff > 30:
return None return message, False
# Ignored message: msg_id belongs over 300 seconds in the past # Ignored message: msg_id belongs over 300 seconds in the past
if time_diff < -300: if time_diff < -300:
return None return message, False
bisect.insort(stored_msg_ids, message.msg_id) bisect.insort(stored_msg_ids, message.msg_id)
return message return message, True

View File

@ -220,17 +220,22 @@ class Session:
await self.start() await self.start()
async def handle_packet(self, packet): async def handle_packet(self, packet):
data = await self.loop.run_in_executor( try:
pyrogram.crypto_executor, data, ok = await self.loop.run_in_executor(
mtproto.unpack, pyrogram.crypto_executor,
BytesIO(packet), mtproto.unpack,
self.session_id, BytesIO(packet),
self.auth_key, self.session_id,
self.auth_key_id, self.auth_key,
self.stored_msg_ids self.auth_key_id,
) self.stored_msg_ids
)
except AssertionError:
self.connection.close()
return
if data is None: if not ok:
self.connection.close()
return return
messages = ( messages = (