402 lines
12 KiB
Python
402 lines
12 KiB
Python
# Pyrogram - Telegram MTProto API Client Library for Python
|
|
# Copyright (C) 2017 Dan Tès <https://github.com/delivrance>
|
|
#
|
|
# This file is part of Pyrogram.
|
|
#
|
|
# Pyrogram is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Lesser General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# Pyrogram is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Lesser General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Lesser General Public License
|
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import logging
|
|
import platform
|
|
import threading
|
|
from datetime import timedelta, datetime
|
|
from hashlib import sha1, sha256
|
|
from io import BytesIO
|
|
from os import urandom
|
|
from queue import Queue
|
|
from threading import Event, Thread
|
|
|
|
from pyrogram import __copyright__, __license__, __version__
|
|
from pyrogram.api import functions, types, core
|
|
from pyrogram.api.all import layer
|
|
from pyrogram.api.core import Message, Object, MsgContainer, Long, FutureSalt, Int
|
|
from pyrogram.api.errors import Error
|
|
from pyrogram.connection import Connection
|
|
from pyrogram.crypto import IGE, KDF
|
|
from .internals import MsgId, MsgFactory, DataCenter
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class Result:
|
|
def __init__(self):
|
|
self.value = None
|
|
self.event = Event()
|
|
|
|
|
|
class Session:
|
|
VERSION = __version__
|
|
APP_VERSION = "Pyrogram \U0001f525 {}".format(VERSION)
|
|
|
|
DEVICE_MODEL = "{} {}".format(
|
|
platform.python_implementation(),
|
|
platform.python_version()
|
|
)
|
|
|
|
SYSTEM_VERSION = "{} {}".format(
|
|
platform.system(),
|
|
platform.release()
|
|
)
|
|
|
|
INITIAL_SALT = 0x616e67656c696361
|
|
|
|
WORKERS = 4
|
|
WAIT_TIMEOUT = 10
|
|
MAX_RETRIES = 5
|
|
ACKS_THRESHOLD = 8
|
|
PING_INTERVAL = 5
|
|
|
|
notice_displayed = False
|
|
|
|
def __init__(self, dc_id: int, test_mode: bool, auth_key: bytes, api_id: str, is_cdn: bool = False):
|
|
if not Session.notice_displayed:
|
|
print("Pyrogram v{}, {}".format(__version__, __copyright__))
|
|
print("Licensed under the terms of the " + __license__, end="\n\n")
|
|
Session.notice_displayed = True
|
|
|
|
self.is_cdn = is_cdn
|
|
|
|
self.connection = Connection(DataCenter(dc_id, test_mode))
|
|
|
|
self.api_id = api_id
|
|
|
|
self.auth_key = auth_key
|
|
self.auth_key_id = sha1(auth_key).digest()[-8:]
|
|
|
|
self.msg_id = MsgId()
|
|
self.session_id = Long(self.msg_id())
|
|
self.msg_factory = MsgFactory(self.msg_id)
|
|
|
|
self.current_salt = None
|
|
|
|
self.pending_acks = set()
|
|
|
|
self.recv_queue = Queue()
|
|
self.results = {}
|
|
|
|
self.ping_thread = None
|
|
self.ping_thread_event = Event()
|
|
|
|
self.next_salt_thread = None
|
|
self.next_salt_thread_event = Event()
|
|
|
|
self.is_connected = Event()
|
|
|
|
self.update_handler = None
|
|
|
|
self.total_connections = 0
|
|
self.total_messages = 0
|
|
self.total_bytes = 0
|
|
|
|
def start(self):
|
|
terms = None
|
|
|
|
while True:
|
|
try:
|
|
self.connection.connect()
|
|
|
|
for i in range(self.WORKERS):
|
|
Thread(target=self.worker, name="Worker#{}".format(i + 1)).start()
|
|
|
|
Thread(target=self.recv, name="RecvThread").start()
|
|
|
|
self.current_salt = FutureSalt(0, 0, self.INITIAL_SALT)
|
|
self.current_salt = FutureSalt(0, 0, self._send(functions.Ping(0)).new_server_salt)
|
|
self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]
|
|
|
|
self.next_salt_thread = Thread(target=self.next_salt, name="NextSaltThread")
|
|
self.next_salt_thread.start()
|
|
|
|
if not self.is_cdn:
|
|
terms = self._send(
|
|
functions.InvokeWithLayer(
|
|
layer,
|
|
functions.InitConnection(
|
|
self.api_id,
|
|
self.DEVICE_MODEL,
|
|
self.SYSTEM_VERSION,
|
|
self.APP_VERSION,
|
|
"en", "", "en",
|
|
functions.help.GetTermsOfService(),
|
|
)
|
|
)
|
|
).text
|
|
|
|
self.ping_thread = Thread(target=self.ping, name="PingThread")
|
|
self.ping_thread.start()
|
|
|
|
log.info("Connection inited: Layer {}".format(layer))
|
|
except (OSError, TimeoutError):
|
|
self.stop()
|
|
else:
|
|
break
|
|
|
|
self.is_connected.set()
|
|
self.total_connections += 1
|
|
|
|
log.debug("Session started")
|
|
|
|
return terms
|
|
|
|
def stop(self):
|
|
self.is_connected.clear()
|
|
|
|
self.ping_thread_event.set()
|
|
self.next_salt_thread_event.set()
|
|
|
|
self.ping_thread.join()
|
|
self.next_salt_thread.join()
|
|
|
|
self.ping_thread_event.clear()
|
|
self.next_salt_thread_event.clear()
|
|
|
|
self.connection.close()
|
|
|
|
for i in range(self.WORKERS):
|
|
self.recv_queue.put(None)
|
|
|
|
log.debug("Session stopped")
|
|
|
|
def restart(self):
|
|
self.stop()
|
|
self.start()
|
|
|
|
def pack(self, message: Message):
|
|
data = Long(self.current_salt.salt) + self.session_id + message.write()
|
|
# MTProto 2.0 requires a minimum of 12 padding bytes.
|
|
# I don't get why it says up to 1024 when what it actually needs after the
|
|
# required 12 bytes is just extra 0..15 padding bytes for aes
|
|
# TODO: It works, but recheck this. What's the meaning of 12..1024 padding bytes?
|
|
padding = urandom(-(len(data) + 12) % 16 + 12)
|
|
|
|
# 88 = 88 + 0 (outgoing message)
|
|
msg_key_large = sha256(self.auth_key[88: 88 + 32] + data + padding).digest()
|
|
msg_key = msg_key_large[8:24]
|
|
aes_key, aes_iv = KDF(self.auth_key, msg_key, True)
|
|
|
|
return self.auth_key_id + msg_key + IGE.encrypt(data + padding, aes_key, aes_iv)
|
|
|
|
def unpack(self, b: BytesIO) -> Message:
|
|
assert b.read(8) == self.auth_key_id, b.getvalue()
|
|
|
|
msg_key = b.read(16)
|
|
aes_key, aes_iv = KDF(self.auth_key, msg_key, False)
|
|
data = BytesIO(IGE.decrypt(b.read(), aes_key, aes_iv))
|
|
data.read(8)
|
|
|
|
# https://core.telegram.org/mtproto/security_guidelines#checking-session-id
|
|
assert data.read(8) == self.session_id
|
|
|
|
message = Message.read(data)
|
|
|
|
# https://core.telegram.org/mtproto/security_guidelines#checking-sha256-hash-value-of-msg-key
|
|
# https://core.telegram.org/mtproto/security_guidelines#checking-message-length
|
|
# 96 = 88 + 8 (incoming message)
|
|
assert msg_key == sha256(self.auth_key[96:96 + 32] + data.getvalue()).digest()[8:24]
|
|
|
|
# https://core.telegram.org/mtproto/security_guidelines#checking-msg-id
|
|
# TODO: check for lower msg_ids
|
|
assert message.msg_id % 2 != 0
|
|
|
|
return message
|
|
|
|
def worker(self):
|
|
name = threading.current_thread().name
|
|
log.debug("{} started".format(name))
|
|
|
|
while True:
|
|
packet = self.recv_queue.get()
|
|
|
|
if packet is None:
|
|
break
|
|
|
|
try:
|
|
self.unpack_dispatch_and_ack(packet)
|
|
except Exception as e:
|
|
log.error(e, exc_info=True)
|
|
|
|
log.debug("{} stopped".format(name))
|
|
|
|
def unpack_dispatch_and_ack(self, packet: bytes):
|
|
# TODO: A better dispatcher
|
|
data = self.unpack(BytesIO(packet))
|
|
|
|
messages = (
|
|
data.body.messages
|
|
if isinstance(data.body, MsgContainer)
|
|
else [data]
|
|
)
|
|
|
|
log.debug(data)
|
|
|
|
self.total_bytes += len(packet)
|
|
self.total_messages += len(messages)
|
|
|
|
for i in messages:
|
|
if i.seq_no % 2 != 0:
|
|
self.pending_acks.add(i.msg_id)
|
|
|
|
# log.debug("{}".format(type(i.body)))
|
|
|
|
if isinstance(i.body, (types.MsgDetailedInfo, types.MsgNewDetailedInfo)):
|
|
self.pending_acks.add(i.body.answer_msg_id)
|
|
continue
|
|
|
|
if isinstance(i.body, types.NewSessionCreated):
|
|
continue
|
|
|
|
msg_id = None
|
|
|
|
if isinstance(i.body, (types.BadMsgNotification, types.BadServerSalt)):
|
|
msg_id = i.body.bad_msg_id
|
|
elif isinstance(i.body, types.RpcResult):
|
|
msg_id = i.body.req_msg_id
|
|
elif isinstance(i.body, types.Pong):
|
|
msg_id = i.body.msg_id
|
|
elif isinstance(i.body, core.FutureSalts):
|
|
msg_id = i.body.req_msg_id
|
|
else:
|
|
if self.update_handler:
|
|
self.update_handler(i.body)
|
|
|
|
if msg_id in self.results:
|
|
self.results[msg_id].value = getattr(i.body, "result", i.body)
|
|
self.results[msg_id].event.set()
|
|
|
|
# print(
|
|
# "This packet bytes: ({}) | Total bytes: ({})\n"
|
|
# "This packet messages: ({}) | Total messages: ({})\n"
|
|
# "Total connections: ({})".format(
|
|
# len(packet), self.total_bytes, len(messages), self.total_messages, self.total_connections
|
|
# )
|
|
# )
|
|
|
|
if len(self.pending_acks) >= self.ACKS_THRESHOLD:
|
|
log.info("Send {} acks".format(len(self.pending_acks)))
|
|
|
|
try:
|
|
self._send(types.MsgsAck(list(self.pending_acks)), False)
|
|
except (OSError, TimeoutError):
|
|
pass
|
|
else:
|
|
self.pending_acks.clear()
|
|
|
|
def ping(self):
|
|
log.debug("PingThread started")
|
|
|
|
while True:
|
|
self.ping_thread_event.wait(self.PING_INTERVAL)
|
|
|
|
if self.ping_thread_event.is_set():
|
|
break
|
|
|
|
try:
|
|
self._send(functions.Ping(0), False)
|
|
except (OSError, TimeoutError):
|
|
pass
|
|
|
|
log.debug("PingThread stopped")
|
|
|
|
def next_salt(self):
|
|
log.debug("NextSaltThread started")
|
|
|
|
while True:
|
|
now = datetime.now()
|
|
|
|
# Seconds to wait until middle-overlap, which is
|
|
# 15 minutes before/after the current/next salt end/start time
|
|
dt = (self.current_salt.valid_until - now).total_seconds() - 900
|
|
|
|
log.debug("Current salt: {} | Next salt in {:.0f}m {:.0f}s ({})".format(
|
|
self.current_salt.salt,
|
|
dt // 60,
|
|
dt % 60,
|
|
now + timedelta(seconds=dt)
|
|
))
|
|
|
|
self.next_salt_thread_event.wait(dt)
|
|
|
|
if self.next_salt_thread_event.is_set():
|
|
break
|
|
|
|
try:
|
|
self.current_salt = self._send(functions.GetFutureSalts(1)).salts[0]
|
|
except (OSError, TimeoutError):
|
|
self.connection.close()
|
|
break
|
|
|
|
log.debug("NextSaltThread stopped")
|
|
|
|
def recv(self):
|
|
log.debug("RecvThread started")
|
|
|
|
while True:
|
|
packet = self.connection.recv()
|
|
|
|
if packet is None or (len(packet) == 4 and Int.read(BytesIO(packet)) == -404):
|
|
if self.is_connected.is_set():
|
|
Thread(target=self.restart, name="RestartThread").start()
|
|
break
|
|
|
|
self.recv_queue.put(packet)
|
|
|
|
log.debug("RecvThread stopped")
|
|
|
|
def _send(self, data: Object, wait_response: bool = True):
|
|
message = self.msg_factory(data)
|
|
msg_id = message.msg_id
|
|
|
|
if wait_response:
|
|
self.results[msg_id] = Result()
|
|
|
|
payload = self.pack(message)
|
|
|
|
try:
|
|
self.connection.send(payload)
|
|
except OSError as e:
|
|
self.results.pop(msg_id, None)
|
|
raise e
|
|
|
|
if wait_response:
|
|
self.results[msg_id].event.wait(self.WAIT_TIMEOUT)
|
|
result = self.results.pop(msg_id).value
|
|
|
|
if result is None:
|
|
raise TimeoutError
|
|
elif isinstance(result, types.RpcError):
|
|
Error.raise_it(result, type(data))
|
|
else:
|
|
return result
|
|
|
|
def send(self, data: Object):
|
|
for i in range(self.MAX_RETRIES):
|
|
self.is_connected.wait()
|
|
|
|
try:
|
|
return self._send(data)
|
|
except (OSError, TimeoutError):
|
|
log.warning("Retrying {}".format(type(data)))
|
|
continue
|
|
else:
|
|
return None
|