Reorganize Session to make use of the MTProto module

This commit is contained in:
Dan 2018-06-14 03:22:52 +02:00
parent 75121c9c57
commit 11ddf5f99d

View File

@ -19,20 +19,18 @@
import asyncio
import logging
import platform
import threading
from datetime import datetime, timedelta
from hashlib import sha1, sha256
from hashlib import sha1
from io import BytesIO
from os import urandom
import pyrogram
from pyrogram import __copyright__, __license__, __version__
from pyrogram.api import functions, types, core
from pyrogram.api.all import layer
from pyrogram.api.core import Message, Object, MsgContainer, Long, FutureSalt, Int
from pyrogram.api.core import Object, MsgContainer, Long, FutureSalt, Int
from pyrogram.api.errors import Error, InternalServerError
from pyrogram.connection import Connection
from pyrogram.crypto import AES, KDF
from pyrogram.crypto import MTProto
from .internals import MsgId, MsgFactory, DataCenter
log = logging.getLogger(__name__)
@ -58,7 +56,6 @@ class Session:
platform.release()
)
INITIAL_SALT = 0x616e67656c696361
NET_WORKERS = 1
WAIT_TIMEOUT = 15
MAX_RETRIES = 5
@ -137,7 +134,7 @@ class Session:
self.net_worker_task = asyncio.ensure_future(self.net_worker())
self.recv_task = asyncio.ensure_future(self.recv())
self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT)
self.current_salt = FutureSalt(0, 0, MTProto.INITIAL_SALT)
self.current_salt = FutureSalt(0, 0, (await self._send(functions.Ping(0))).new_server_salt)
self.current_salt = (await self._send(functions.GetFutureSalts(1))).salts[0]
@ -215,36 +212,40 @@ class Session:
data = Long(self.current_salt.salt) + self.session_id + message.write()
padding = urandom(-(len(data) + 12) % 16 + 12)
# 88 = 88 + 0 (outgoing message)
msg_key_large = sha256(self.auth_key[88: 88 + 32] + data + padding).digest()
msg_key = msg_key_large[8:24]
aes_key, aes_iv = KDF(self.auth_key, msg_key, True)
return self.auth_key_id + msg_key + AES.ige256_encrypt(data + padding, aes_key, aes_iv)
def unpack(self, b: BytesIO) -> Message:
assert b.read(8) == self.auth_key_id, b.getvalue()
msg_key = b.read(16)
aes_key, aes_iv = KDF(self.auth_key, msg_key, False)
data = BytesIO(AES.ige256_decrypt(b.read(), aes_key, aes_iv))
data.read(8)
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id
assert data.read(8) == self.session_id
message = Message.read(data)
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
# 96 = 88 + 8 (incoming message)
assert msg_key == sha256(self.auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
# TODO: check for lower msg_ids
assert message.msg_id % 2 != 0
return message
# def pack(self, message: Message):
# data = Long(self.current_salt.salt) + self.session_id + message.write()
# padding = urandom(-(len(data) + 12) % 16 + 12)
#
# # 88 = 88 + 0 (outgoing message)
# msg_key_large = sha256(self.auth_key[88: 88 + 32] + data + padding).digest()
# msg_key = msg_key_large[8:24]
# aes_key, aes_iv = KDF(self.auth_key, msg_key, True)
#
# return self.auth_key_id + msg_key + AES.ige256_encrypt(data + padding, aes_key, aes_iv)
#
# def unpack(self, b: BytesIO) -> Message:
# assert b.read(8) == self.auth_key_id, b.getvalue()
#
# msg_key = b.read(16)
# aes_key, aes_iv = KDF(self.auth_key, msg_key, False)
# data = BytesIO(AES.ige256_decrypt(b.read(), aes_key, aes_iv))
# data.read(8)
#
# # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
# assert data.read(8) == self.session_id
#
# message = Message.read(data)
#
# # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
# # https://core.telegram.org/mtproto/security_guidelines#checking-message-length
# # 96 = 88 + 8 (incoming message)
# assert msg_key == sha256(self.auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]
#
# # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
# # TODO: check for lower msg_ids
# assert message.msg_id % 2 != 0
#
# return message
async def net_worker(self):
name = threading.current_thread().name
@ -257,7 +258,13 @@ class Session:
break
try:
data = self.unpack(BytesIO(packet))
data = MTProto.unpack(
BytesIO(packet),
self.current_salt.salt,
self.session_id,
self.auth_key,
self.auth_key_id
)
messages = (
data.body.messages
@ -391,7 +398,13 @@ class Session:
if wait_response:
self.results[msg_id] = Result()
payload = self.pack(message)
payload = MTProto.pack(
message,
self.current_salt.salt,
self.session_id,
self.auth_key,
self.auth_key_id
)
try:
await self.connection.send(payload)