MTPyroger/pyrogram/session/session.py

431 lines
14 KiB
Python
Raw Normal View History

2017-12-05 11:41:07 +00:00
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017 Dan Tès <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
import platform
import threading
from datetime import timedelta, datetime
2017-12-09 01:21:23 +00:00
from hashlib import sha1, sha256
2017-12-05 11:41:07 +00:00
from io import BytesIO
from os import urandom
from queue import Queue
from threading import Event, Thread
from pyrogram import __copyright__, __license__, __version__
from pyrogram.api import functions, types, core
from pyrogram.api.all import layer
2017-12-18 08:50:41 +00:00
from pyrogram.api.core import Message, Object, MsgContainer, Long, FutureSalt, Int
2017-12-05 11:41:07 +00:00
from pyrogram.api.errors import Error
from pyrogram.connection import Connection
2017-12-09 01:25:14 +00:00
from pyrogram.crypto import IGE, KDF2
2017-12-05 11:41:07 +00:00
from .internals import MsgId, MsgFactory, DataCenter
# Module-level logger shared by every Session in this module.
log = logging.getLogger(name=__name__)
class Result:
    """Slot in which a worker thread deposits the response to one request.

    The sending thread waits on ``event``. A worker stores the response body in
    ``value`` and then sets ``event``. ``value`` stays ``None`` until a
    response arrives.
    """

    def __init__(self):
        # Signalled once ``value`` has been filled in.
        self.event = Event()
        # Response payload, or None while still pending.
        self.value = None
class Session:
    """A single MTProto 2.0 session bound to one data center.

    A started session owns a Connection plus background threads:
    ``WORKERS`` worker threads that decrypt and dispatch incoming packets,
    one receiver thread that feeds them through ``recv_queue``, a ping
    thread, and a thread that refreshes the server salt before it expires.
    Requests go out through :meth:`send` (retrying) or :meth:`_send`
    (single attempt).
    """

    VERSION = __version__
    APP_VERSION = "Pyrogram \U0001f525 {}".format(VERSION)

    DEVICE_MODEL = "{} {}".format(
        platform.python_implementation(),
        platform.python_version()
    )

    SYSTEM_VERSION = "{} {}".format(
        platform.system(),
        platform.release()
    )

    # Placeholder salt used only for the first Ping of each connection attempt
    # (see start()). It is replaced right after by the salt the server reports.
    INITIAL_SALT = 0x616e67656c696361

    # Number of worker threads consuming recv_queue.
    WORKERS = 4
    # Seconds _send() waits for a response before raising TimeoutError.
    WAIT_TIMEOUT = 10
    # Attempts made by send() before giving up and returning None.
    MAX_RETRIES = 5
    # Pending acknowledgements are flushed once at least this many accumulate.
    ACKS_THRESHOLD = 8
    # Seconds between two pings sent by the ping thread.
    PING_INTERVAL = 5

    # Class-wide flag: the version/license notice is printed by the first
    # Session constructed only.
    notice_displayed = False

    def __init__(self, dc_id: int, test_mode: bool, auth_key: bytes, api_id: str):
        """Prepare, but do not start, a session.

        Args:
            dc_id: Data center identifier, forwarded to DataCenter.
            test_mode: Forwarded to DataCenter (presumably selects the test
                servers -- confirm in DataCenter).
            auth_key: Authorization key shared with the server.
            api_id: Application identifier sent in InitConnection.
        """
        if not Session.notice_displayed:
            print("Pyrogram v{}, {}".format(__version__, __copyright__))
            print("Licensed under the terms of the " + __license__, end="\n\n")
            Session.notice_displayed = True

        self.connection = Connection(DataCenter(dc_id, test_mode))

        self.api_id = api_id

        self.auth_key = auth_key
        # auth_key_id is the last 8 bytes of SHA1(auth_key).
        self.auth_key_id = sha1(auth_key).digest()[-8:]

        self.msg_id = MsgId()
        # The session id is a freshly generated message id, serialized as a Long.
        self.session_id = Long(self.msg_id())
        self.msg_factory = MsgFactory(self.msg_id)

        self.current_salt = None

        # msg_ids received from the server that still have to be acknowledged.
        self.pending_acks = set()
        # Raw packets passed from the receiver thread to the workers; a None
        # item tells one worker to exit.
        self.recv_queue = Queue()
        # msg_id -> Result for requests that are waiting for a response.
        self.results = {}

        self.ping_thread = None
        self.ping_thread_event = Event()

        self.next_salt_thread = None
        self.next_salt_thread_event = Event()

        self.is_connected = Event()

        # Optional callable that receives every incoming message body not
        # handled as a service/response type in unpack_dispatch_and_ack().
        self.update_handler = None

        self.total_connections = 0
        self.total_messages = 0
        self.total_bytes = 0

    def start(self):
        """Connect and initialize the session, retrying until it succeeds.

        The sequence is:
        1. Connect and spawn the worker and receiver threads.
        2. Obtain a valid server salt.
        3. Start the salt refresher.
        4. Send InitConnection (wrapped in InvokeWithLayer) carrying a
           GetTermsOfService query.
        5. Start the ping thread.

        Any OSError/TimeoutError tears everything down via stop() and the
        whole sequence is retried.

        Returns:
            The ``text`` of the terms-of-service object returned by the server.
        """
        # NOTE(review): failed attempts are retried immediately with no delay;
        # consider a backoff if Connection.connect() does not provide one.
        while True:
            try:
                self.connection.connect()

                for i in range(self.WORKERS):
                    Thread(target=self.worker, name="Worker#{}".format(i + 1)).start()

                Thread(target=self.recv, name="RecvThread").start()

                # The first Ping is sent with a placeholder salt. The reply is
                # expected to be a BadServerSalt (dispatched by bad_msg_id),
                # whose new_server_salt is adopted before asking for the
                # proper future salts.
                self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT)
                self.current_salt = FutureSalt(0, 0, self._send(functions.Ping(0)).new_server_salt)
                self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]

                if self.next_salt_thread is not None:
                    self.next_salt_thread.join()

                self.next_salt_thread_event.clear()

                self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread")
                self.next_salt_thread.start()

                terms = self._send(
                    functions.InvokeWithLayer(
                        layer,
                        functions.InitConnection(
                            self.api_id,
                            self.DEVICE_MODEL,
                            self.SYSTEM_VERSION,
                            self.APP_VERSION,
                            "en", "", "en",
                            functions.help.GetTermsOfService(),
                        )
                    )
                )

                if self.ping_thread is not None:
                    self.ping_thread.join()

                self.ping_thread_event.clear()

                self.ping_thread = Thread(target=self.ping, name="PingThread")
                self.ping_thread.start()

                log.info("Connection inited: Layer {}".format(layer))
            except (OSError, TimeoutError):
                self.stop()
            else:
                break

        self.is_connected.set()
        self.total_connections += 1

        log.debug("Session started")

        return terms.text

    def stop(self):
        """Tear the session down.

        Marks the session disconnected, signals the ping and salt threads to
        exit, closes the connection, and queues one stop sentinel per worker.
        """
        self.is_connected.clear()

        self.ping_thread_event.set()
        self.next_salt_thread_event.set()

        self.connection.close()

        for _ in range(self.WORKERS):
            self.recv_queue.put(None)

        log.debug("Session stopped")

    def restart(self):
        """Stop the session and start it again."""
        self.stop()
        self.start()

    def pack2(self, message: Message):
        """Serialize and encrypt *message* following MTProto 2.0.

        Returns:
            auth_key_id (8 bytes) + msg_key (16 bytes) + the AES-IGE encrypted
            payload, which is salt + session_id + message + padding.
        """
        data = Long(self.current_salt.salt) + self.session_id + message.write()

        # MTProto 2.0 mandates 12..1024 bytes of random padding and a total
        # length that is a multiple of 16 (the AES block size). Anything
        # beyond the first 12 bytes is optional. Only the minimum needed for
        # block alignment is added here, giving 12..27 bytes.
        padding = urandom(-(len(data) + 12) % 16 + 12)

        # msg_key is the middle 128 bits of SHA256 over a 32-byte slice of the
        # auth key (offset 88 + x, x = 0 for outgoing messages) followed by
        # the plaintext including padding.
        msg_key_large = sha256(self.auth_key[88: 88 + 32] + data + padding).digest()
        msg_key = msg_key_large[8:24]
        aes_key, aes_iv = KDF2(self.auth_key, msg_key, True)

        return self.auth_key_id + msg_key + IGE.encrypt(data + padding, aes_key, aes_iv)

    def unpack2(self, b: BytesIO) -> Message:
        """Decrypt, authenticate and parse an incoming MTProto 2.0 packet.

        The checks are explicit ``raise`` statements rather than ``assert``
        so that they are still performed when Python runs with ``-O``.

        Args:
            b: Stream positioned at the start of an encrypted packet.

        Returns:
            The decoded Message.

        Raises:
            ValueError: If the auth_key_id, msg_key, session_id or msg_id
                check fails.
        """
        if b.read(8) != self.auth_key_id:
            raise ValueError("Unexpected auth_key_id in packet: {}".format(b.getvalue()))

        msg_key = b.read(16)
        aes_key, aes_iv = KDF2(self.auth_key, msg_key, False)
        plaintext = IGE.decrypt(b.read(), aes_key, aes_iv)

        # https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
        # Authenticate the whole decrypted payload (padding included) before
        # parsing anything out of it.
        # 96 = 88 + 8 (incoming message)
        if msg_key != sha256(self.auth_key[96:96 + 32] + plaintext).digest()[8:24]:
            raise ValueError("msg_key mismatch: packet failed authentication")

        data = BytesIO(plaintext)
        data.read(8)  # Server salt

        # https://core.telegram.org/mtproto/security_guidelines#checking-session-id
        if data.read(8) != self.session_id:
            raise ValueError("session_id mismatch")

        message = Message.read(data)

        # NOTE: the padding-length check described in
        # https://core.telegram.org/mtproto/security_guidelines#checking-message-length
        # is not performed here.

        # https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
        # Server messages must carry odd msg_ids.
        # TODO: check for lower msg_ids
        if message.msg_id % 2 == 0:
            raise ValueError("Even msg_id {} received from server".format(message.msg_id))

        return message

    def worker(self):
        """Consume recv_queue until a None sentinel arrives.

        Per-packet errors are logged and do not stop the thread.
        """
        name = threading.current_thread().name
        log.debug("{} started".format(name))

        while True:
            packet = self.recv_queue.get()

            if packet is None:
                break

            try:
                self.unpack_dispatch_and_ack(packet)
            except Exception as e:
                log.error(e, exc_info=True)

        log.debug("{} stopped".format(name))

    def unpack_dispatch_and_ack(self, packet: bytes):
        """Decrypt one packet, route its messages and flush acks if due.

        A message is routed as follows:
        - Content-related messages (odd seq_no) are queued for acknowledgement.
        - MsgDetailedInfo/MsgNewDetailedInfo queue their answer_msg_id for
          acknowledgement.
        - NewSessionCreated is ignored.
        - Responses to our requests (BadMsgNotification, BadServerSalt,
          RpcResult, Pong, FutureSalts) are delivered to the waiting Result.
        - Everything else goes to update_handler, if one is set.
        """
        # TODO: A better dispatcher
        data = self.unpack2(BytesIO(packet))

        messages = (
            data.body.messages
            if isinstance(data.body, MsgContainer)
            else [data]
        )

        log.debug(data)

        self.total_bytes += len(packet)
        self.total_messages += len(messages)

        for i in messages:
            if i.seq_no % 2 != 0:
                self.pending_acks.add(i.msg_id)

            if isinstance(i.body, (types.MsgDetailedInfo, types.MsgNewDetailedInfo)):
                self.pending_acks.add(i.body.answer_msg_id)
                continue

            if isinstance(i.body, types.NewSessionCreated):
                continue

            msg_id = None

            if isinstance(i.body, (types.BadMsgNotification, types.BadServerSalt)):
                msg_id = i.body.bad_msg_id
            elif isinstance(i.body, types.RpcResult):
                msg_id = i.body.req_msg_id
            elif isinstance(i.body, types.Pong):
                msg_id = i.body.msg_id
            elif isinstance(i.body, core.FutureSalts):
                msg_id = i.body.req_msg_id
            else:
                if self.update_handler:
                    self.update_handler(i.body)

            # Single lookup: _send() may pop the slot concurrently after its
            # own timeout, so "in" followed by indexing could raise KeyError.
            result = self.results.get(msg_id)

            if result is not None:
                result.value = getattr(i.body, "result", i.body)
                result.event.set()

        if len(self.pending_acks) >= self.ACKS_THRESHOLD:
            # Snapshot the ids being acknowledged. Other workers may keep
            # adding to pending_acks while this MsgsAck is being sent, and
            # those newer ids must not be discarded.
            acks = list(self.pending_acks)
            log.info("Send {} acks".format(len(acks)))

            try:
                self._send(types.MsgsAck(acks), False)
            except (OSError, TimeoutError):
                # Best effort: keep the ids so a later flush retries them.
                pass
            else:
                self.pending_acks.difference_update(acks)

    def ping(self):
        """Send a Ping (without waiting for the Pong) every PING_INTERVAL seconds.

        Runs until ping_thread_event is set. Send failures are ignored.
        """
        log.debug("PingThread started")

        while True:
            self.ping_thread_event.wait(self.PING_INTERVAL)

            if self.ping_thread_event.is_set():
                break

            try:
                self._send(functions.Ping(0), False)
            except (OSError, TimeoutError):
                pass

        log.debug("PingThread stopped")

    def next_salt(self):
        """Refresh current_salt shortly before it expires.

        Runs until next_salt_thread_event is set. If fetching a new salt
        fails, the connection is closed and the thread exits.
        """
        log.debug("NextSaltThread started")

        while True:
            # Assumes valid_until is a naive datetime comparable with
            # datetime.now() -- confirm in FutureSalt.
            now = datetime.now()

            # Seconds to wait until middle-overlap, which is
            # 15 minutes before/after the current/next salt end/start time
            dt = (self.current_salt.valid_until - now).total_seconds() - 900

            log.debug("Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format(
                self.current_salt.salt,
                dt // 60,
                dt % 60,
                now + timedelta(seconds=dt)
            ))

            self.next_salt_thread_event.wait(dt)

            if self.next_salt_thread_event.is_set():
                break

            try:
                self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]
            except (OSError, TimeoutError):
                self.connection.close()
                break

        log.debug("NextSaltThread stopped")

    def recv(self):
        """Read packets from the connection and hand them to the workers.

        Stops on a None packet or on a 4-byte packet holding the error code
        -404. If the session was connected at that moment, a restart is
        triggered on a separate thread.
        """
        log.debug("RecvThread started")

        while True:
            packet = self.connection.recv()

            if packet is None or (len(packet) == 4 and Int.read(BytesIO(packet)) == -404):
                if self.is_connected.is_set():
                    Thread(target=self.restart, name="RestartThread").start()
                break

            self.recv_queue.put(packet)

        log.debug("RecvThread stopped")

    def _send(self, data: Object, wait_response: bool = True):
        """Send *data* once and, optionally, wait for the matching response.

        Args:
            data: The TL object to send.
            wait_response: If False, return None right after sending.

        Returns:
            The response body (the ``result`` field for RPC results), or
            None when ``wait_response`` is False.

        Raises:
            OSError: If the connection fails while sending.
            TimeoutError: If no response arrives within WAIT_TIMEOUT seconds.
            Error: Via Error.raise_it, when the server answers with an RpcError.
        """
        message = self.msg_factory(data)
        msg_id = message.msg_id

        # Register the slot before sending so a fast reply cannot be missed.
        if wait_response:
            self.results[msg_id] = Result()

        payload = self.pack2(message)

        try:
            self.connection.send(payload)
        except OSError:
            self.results.pop(msg_id, None)
            raise

        if wait_response:
            self.results[msg_id].event.wait(self.WAIT_TIMEOUT)
            result = self.results.pop(msg_id).value

            if result is None:
                raise TimeoutError
            elif isinstance(result, types.RpcError):
                Error.raise_it(result, type(data))
            else:
                return result

    def send(self, data: Object):
        """Send *data* and return its response, retrying on network errors.

        Each attempt first waits until the session is connected. Up to
        MAX_RETRIES attempts are made.

        Returns:
            The response, or None if every attempt failed with OSError or
            TimeoutError.
        """
        for _ in range(self.MAX_RETRIES):
            self.is_connected.wait()

            try:
                return self._send(data)
            except (OSError, TimeoutError):
                log.warning("Retrying {}".format(type(data)))

        # All attempts failed: callers get None rather than an exception.
        return None