MTPyroger/pyrogram/session/session.py

404 lines
13 KiB
Python
Raw Normal View History

2017-12-05 11:41:07 +00:00
# Pyrogram - Telegram MTProto API Client Library for Python
2019-01-01 11:36:16 +00:00
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
2017-12-05 11:41:07 +00:00
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
2018-06-12 13:56:33 +00:00
import asyncio
2017-12-05 11:41:07 +00:00
import logging
2018-06-12 13:56:33 +00:00
from datetime import datetime, timedelta
from hashlib import sha1
2017-12-05 11:41:07 +00:00
from io import BytesIO
import pyrogram
2017-12-05 11:41:07 +00:00
from pyrogram import __copyright__, __license__, __version__
2018-06-14 11:30:46 +00:00
from pyrogram.api import functions, types
2017-12-05 11:41:07 +00:00
from pyrogram.api.all import layer
2018-06-14 11:30:46 +00:00
from pyrogram.api.core import Object, MsgContainer, Int, Long, FutureSalt, FutureSalts
2018-06-27 22:16:12 +00:00
from pyrogram.api.errors import Error, InternalServerError, AuthKeyDuplicated
2017-12-05 11:41:07 +00:00
from pyrogram.connection import Connection
from pyrogram.crypto import MTProto
2018-06-13 11:37:35 +00:00
from .internals import MsgId, MsgFactory
2017-12-05 11:41:07 +00:00
log = logging.getLogger(__name__)


class Result:
    """Rendezvous slot for a single request that expects a reply.

    The requester waits on ``event``; whoever receives the reply stores it in
    ``value`` and then sets ``event``.
    """

    def __init__(self):
        self.event = asyncio.Event()
        # Remains None until a reply is delivered.
        self.value = None
2017-12-05 11:41:07 +00:00
class Session:
    """A single MTProto session with one datacenter.

    A session owns a Connection plus background asyncio tasks:

    * ``recv``       - reads raw packets from the connection into ``recv_queue``;
    * ``net_worker`` - unpacks those packets and dispatches replies/updates;
    * ``ping``       - periodic keep-alive pings;
    * ``next_salt``  - refreshes the server salt before it expires.
    """

    # Placeholder salt used only for the very first Ping in start(); the reply
    # carries the real salt.
    INITIAL_SALT = 0x616e67656c696361
    NET_WORKERS = 1
    # Seconds to wait for a reply, and for the session to become connected in send().
    WAIT_TIMEOUT = 15
    MAX_RETRIES = 5
    # Number of queued acknowledgements that triggers sending a MsgsAck.
    ACKS_THRESHOLD = 8
    # Seconds between keep-alive pings.
    PING_INTERVAL = 5

    # Class-level so the banner is printed only once per process.
    notice_displayed = False

    BAD_MSG_DESCRIPTION = {
        16: "[16] msg_id too low, the client time has to be synchronized",
        17: "[17] msg_id too high, the client time has to be synchronized",
        18: "[18] incorrect two lower order msg_id bits, the server expects client message msg_id to be divisible by 4",
        19: "[19] container msg_id is the same as msg_id of a previously received message",
        20: "[20] message too old, it cannot be verified by the server",
        32: "[32] msg_seqno too low",
        33: "[33] msg_seqno too high",
        34: "[34] an even msg_seqno expected, but odd received",
        35: "[35] odd msg_seqno expected, but even received",
        48: "[48] incorrect server salt",
        64: "[64] invalid container"
    }

    def __init__(self,
                 client: "pyrogram.Client",
                 dc_id: int,
                 auth_key: bytes,
                 is_media: bool = False,
                 is_cdn: bool = False):
        """
        :param client: owning client; this session reads its connection settings
            (test_mode, ipv6, proxy), InitConnection fields, updates_queue and
            disconnect_handler.
        :param dc_id: datacenter to connect to.
        :param auth_key: authorization key used to pack/unpack every message.
        :param is_media: media sessions never invoke the client's disconnect handler.
        :param is_cdn: CDN sessions skip the InvokeWithLayer/InitConnection request.
        """
        if not Session.notice_displayed:
            print("Pyrogram v{}, {}".format(__version__, __copyright__))
            print("Licensed under the terms of the " + __license__, end="\n\n")
            Session.notice_displayed = True

        self.client = client
        self.dc_id = dc_id
        self.auth_key = auth_key
        self.is_media = is_media
        self.is_cdn = is_cdn

        self.connection = None

        # Key id: the last 8 bytes of SHA-1(auth_key).
        self.auth_key_id = sha1(auth_key).digest()[-8:]

        self.session_id = Long(MsgId())
        self.msg_factory = MsgFactory()

        self.current_salt = None

        # msg_ids that still have to be acknowledged to the server.
        self.pending_acks = set()

        # Raw packets from recv() to net_worker(); None is the stop sentinel.
        self.recv_queue = asyncio.Queue()

        # msg_id -> Result for every request awaiting a reply.
        self.results = {}

        self.ping_task = None
        self.ping_task_event = asyncio.Event()

        self.next_salt_task = None
        self.next_salt_task_event = asyncio.Event()

        self.net_worker_task = None
        self.recv_task = None

        # Reference to the task scheduled by recv() so it is not garbage collected.
        self._restart_task = None

        self.is_connected = asyncio.Event()

    async def start(self):
        """Connect and initialise the session, retrying until it succeeds.

        Each attempt creates a fresh Connection, starts the net worker and
        receiver, obtains a server salt, starts the salt refresher, sends
        InitConnection (wrapped in InvokeWithLayer) unless this is a CDN
        session, and starts the ping task.

        OSError, TimeoutError and Error tear the attempt down and retry.
        AuthKeyDuplicated and any other exception tear it down and propagate.
        """
        while True:
            self.connection = Connection(self.dc_id, self.client.test_mode, self.client.ipv6, self.client.proxy)

            try:
                await self.connection.connect()

                self.net_worker_task = asyncio.ensure_future(self.net_worker())
                self.recv_task = asyncio.ensure_future(self.recv())

                # Ping with a placeholder salt; the reply is expected to carry
                # ``new_server_salt`` (a BadServerSalt), which is then used to
                # request a proper future salt.
                self.current_salt = FutureSalt(0, 0, Session.INITIAL_SALT)
                self.current_salt = FutureSalt(0, 0, (await self._send(functions.Ping(0))).new_server_salt)
                self.current_salt = (await self._send(functions.GetFutureSalts(1))).salts[0]

                self.next_salt_task = asyncio.ensure_future(self.next_salt())

                if not self.is_cdn:
                    await self._send(
                        functions.InvokeWithLayer(
                            layer,
                            functions.InitConnection(
                                api_id=self.client.api_id,
                                app_version=self.client.app_version,
                                device_model=self.client.device_model,
                                system_version=self.client.system_version,
                                system_lang_code=self.client.lang_code,
                                lang_code=self.client.lang_code,
                                lang_pack="",
                                query=functions.help.GetConfig(),
                            )
                        )
                    )

                self.ping_task = asyncio.ensure_future(self.ping())

                log.info("Session initialized: Layer {}".format(layer))
                log.info("Device: {} - {}".format(self.client.device_model, self.client.app_version))
                log.info("System: {} ({})".format(self.client.system_version, self.client.lang_code.upper()))
            except AuthKeyDuplicated as e:
                # Checked before Error so it is never retried.
                await self.stop()
                raise e
            except (OSError, TimeoutError, Error) as e:
                # Transient failure: tear this attempt down and try again.
                log.warning("Session start failed ({}: {}), retrying".format(type(e).__name__, e))
                await self.stop()
            except Exception as e:
                await self.stop()
                raise e
            else:
                break

        self.is_connected.set()

        log.info("Session started")

    @staticmethod
    async def _join(task):
        """Wait for a background task to finish; log (don't raise) its failure."""
        if task is None:
            return

        try:
            await task
        except Exception as e:
            log.error(e, exc_info=True)

    async def stop(self):
        """Tear the session down.

        Stops the background tasks, closes the connection, releases every
        caller still waiting in _send() (they then raise TimeoutError) and,
        for non-media sessions, invokes the client's disconnect handler.
        Safe to call when start() failed part-way or never ran.
        """
        self.is_connected.clear()

        # Wake ping() and next_salt() out of their interval waits so they exit.
        self.ping_task_event.set()
        self.next_salt_task_event.set()

        # A task that died with an exception must not abort the teardown.
        await self._join(self.ping_task)
        await self._join(self.next_salt_task)
        self.ping_task = None
        self.next_salt_task = None

        self.ping_task_event.clear()
        self.next_salt_task_event.clear()

        if self.connection is not None:
            # Presumably makes connection.recv() return None, which ends
            # recv() and, through the queued sentinel, net_worker().
            self.connection.close()

        await self._join(self.recv_task)
        await self._join(self.net_worker_task)
        self.recv_task = None
        self.net_worker_task = None

        # No value is set, so every woken _send() raises TimeoutError.
        for result in self.results.values():
            result.event.set()

        if not self.is_media and callable(self.client.disconnect_handler):
            try:
                outcome = self.client.disconnect_handler(self.client)

                # Accept plain functions as well as coroutine functions.
                if asyncio.iscoroutine(outcome) or asyncio.isfuture(outcome):
                    await outcome
            except Exception as e:
                log.error(e, exc_info=True)

        log.info("Session stopped")

    async def restart(self):
        """Stop the session and start it again."""
        await self.stop()
        await self.start()

    async def net_worker(self):
        """Unpack packets from recv_queue and dispatch their messages.

        Replies are routed to the matching Result in self.results; anything
        that is not a reply is pushed to client.updates_queue. Runs until the
        None sentinel is dequeued.
        """
        log.info("NetWorkerTask started")

        while True:
            packet = await self.recv_queue.get()

            if packet is None:
                break

            try:
                data = MTProto.unpack(
                    BytesIO(packet),
                    self.session_id,
                    self.auth_key,
                    self.auth_key_id
                )

                messages = (
                    data.body.messages
                    if isinstance(data.body, MsgContainer)
                    else [data]
                )

                log.debug(data)

                for msg in messages:
                    # Odd seq_no: the message must be acknowledged.
                    if msg.seq_no % 2 != 0:
                        if msg.msg_id in self.pending_acks:
                            # Already queued for ack, i.e. a duplicate: skip it.
                            continue
                        else:
                            self.pending_acks.add(msg.msg_id)

                    if isinstance(msg.body, (types.MsgDetailedInfo, types.MsgNewDetailedInfo)):
                        self.pending_acks.add(msg.body.answer_msg_id)
                        continue

                    if isinstance(msg.body, types.NewSessionCreated):
                        continue

                    # Figure out which of our requests (if any) this message answers.
                    msg_id = None

                    if isinstance(msg.body, (types.BadMsgNotification, types.BadServerSalt)):
                        msg_id = msg.body.bad_msg_id
                    elif isinstance(msg.body, (FutureSalts, types.RpcResult)):
                        msg_id = msg.body.req_msg_id
                    elif isinstance(msg.body, types.Pong):
                        msg_id = msg.body.msg_id
                    else:
                        if self.client is not None:
                            self.client.updates_queue.put_nowait(msg.body)

                    if msg_id in self.results:
                        # Unwrap RpcResult-like bodies to their ``result`` field.
                        self.results[msg_id].value = getattr(msg.body, "result", msg.body)
                        self.results[msg_id].event.set()

                if len(self.pending_acks) >= self.ACKS_THRESHOLD:
                    log.info("Send {} acks".format(len(self.pending_acks)))

                    try:
                        await self._send(types.MsgsAck(list(self.pending_acks)), False)
                    except (OSError, TimeoutError):
                        # Keep the acks and try again on a later packet.
                        pass
                    else:
                        self.pending_acks.clear()
            except Exception as e:
                log.error(e, exc_info=True)

        log.info("NetWorkerTask stopped")

    async def ping(self):
        """Keep-alive loop.

        Every PING_INTERVAL seconds, sends PingDelayDisconnect (second
        argument: WAIT_TIMEOUT + 10, presumably the server-side disconnect
        delay in seconds) without waiting for the reply. Runs until stop()
        sets ping_task_event. Send failures are ignored; the next tick retries.
        """
        log.info("PingTask started")

        while True:
            try:
                await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL)
            except asyncio.TimeoutError:
                pass
            else:
                break

            try:
                await self._send(
                    functions.PingDelayDisconnect(
                        0, self.WAIT_TIMEOUT + 10
                    ), False
                )
            except (OSError, TimeoutError, Error):
                pass

        log.info("PingTask stopped")

    async def next_salt(self):
        """Refresh self.current_salt 15 minutes before the current one expires.

        Runs until stop() sets next_salt_task_event. If fetching a new salt
        fails for any reason, the connection is closed. recv() then schedules
        a restart when the session is connected, and this task ends.
        """
        log.info("NextSaltTask started")

        while True:
            now = datetime.now()

            # Seconds to wait until middle-overlap, which is
            # 15 minutes before/after the current/next salt end/start time.
            # Clamp at 0: a salt already inside that window is refreshed now.
            dt = max((self.current_salt.valid_until - now).total_seconds() - 900, 0)

            log.info("Next salt in {:.0f}m {:.0f}s ({})".format(
                dt // 60, dt % 60,
                now + timedelta(seconds=dt)
            ))

            try:
                await asyncio.wait_for(self.next_salt_task_event.wait(), dt)
            except asyncio.TimeoutError:
                pass
            else:
                break

            try:
                self.current_salt = (await self._send(functions.GetFutureSalts(1))).salts[0]
            except Exception as e:
                # _send raises a bare Exception for BadMsgNotification, so catch
                # broadly: any failure here leaves the salt stale.
                log.warning("Unable to fetch the next salt ({}: {})".format(type(e).__name__, e))
                self.connection.close()
                break

        log.info("NextSaltTask stopped")

    async def recv(self):
        """Move packets from the connection into recv_queue.

        A None packet, or a 4-byte packet (logged as the int it encodes),
        ends the loop. The None sentinel is queued for net_worker() and, if
        the session is connected, a restart is scheduled.
        """
        log.info("RecvTask started")

        while True:
            packet = await self.connection.recv()

            if packet is None or len(packet) == 4:
                self.recv_queue.put_nowait(None)

                if packet:
                    log.warning("Server sent \"{}\"".format(Int.read(BytesIO(packet))))

                if self.is_connected.is_set():
                    # Keep a reference so the task cannot be garbage collected mid-run.
                    self._restart_task = asyncio.ensure_future(self.restart())

                break

            self.recv_queue.put_nowait(packet)

        log.info("RecvTask stopped")

    async def _send(self, data: Object, wait_response: bool = True, timeout: float = WAIT_TIMEOUT):
        """Pack and send *data*, optionally waiting for the reply.

        :param data: TL object to send.
        :param wait_response: if False, return None as soon as the packet is sent.
        :param timeout: seconds to wait for the reply.
        :returns: the reply (unwrapped from RpcResult) when *wait_response* is True.
        :raises OSError: the packet could not be sent.
        :raises TimeoutError: no reply arrived in time, or stop() released the waiter.
        :raises Error: the reply is an RpcError (presumably raised by Error.raise_it).
        :raises Exception: the reply is a BadMsgNotification.
        """
        message = self.msg_factory(data)
        msg_id = message.msg_id

        if wait_response:
            # Register before sending so net_worker() can deliver an early reply.
            self.results[msg_id] = Result()

        try:
            payload = MTProto.pack(
                message,
                self.current_salt.salt,
                self.session_id,
                self.auth_key,
                self.auth_key_id
            )

            await self.connection.send(payload)

            if not wait_response:
                return None

            try:
                await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
            except asyncio.TimeoutError:
                pass

            result = self.results[msg_id].value
        finally:
            # Always release the slot: on success, timeout, send failure,
            # packing failure or cancellation of the awaiting task.
            self.results.pop(msg_id, None)

        if result is None:
            raise TimeoutError
        elif isinstance(result, types.RpcError):
            Error.raise_it(result, type(data))
        elif isinstance(result, types.BadMsgNotification):
            raise Exception(self.BAD_MSG_DESCRIPTION.get(
                result.error_code,
                "Error code {}".format(result.error_code)
            ))
        else:
            return result

    async def send(self, data: Object, retries: int = MAX_RETRIES, timeout: float = WAIT_TIMEOUT):
        """Send *data* and return the reply, retrying on transient failures.

        Waits up to WAIT_TIMEOUT seconds for the session to be connected, then
        proceeds regardless. OSError, TimeoutError and InternalServerError are
        retried *retries* more times, 0.5 s apart. After that the last one is
        re-raised without its context.
        """
        try:
            await asyncio.wait_for(self.is_connected.wait(), self.WAIT_TIMEOUT)
        except asyncio.TimeoutError:
            pass

        try:
            return await self._send(data, timeout=timeout)
        except (OSError, TimeoutError, InternalServerError) as e:
            if retries == 0:
                raise e from None

            (log.warning if retries < 3 else log.info)(
                "{}: {} Retrying {}".format(
                    Session.MAX_RETRIES - retries,
                    datetime.now(), type(data)))

            await asyncio.sleep(0.5)
            return await self.send(data, retries - 1, timeout)