diff --git a/pyrogram/client.py b/pyrogram/client.py index 63e4b472..36ab4e4c 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -26,6 +26,7 @@ import re import shutil import sys from concurrent.futures.thread import ThreadPoolExecutor +from datetime import datetime, timedelta from hashlib import sha256 from importlib import import_module from io import StringIO, BytesIO @@ -185,6 +186,9 @@ class Client(Methods): WORKERS = min(32, (os.cpu_count() or 0) + 4) # os.cpu_count() can be None WORKDIR = PARENT_DIR + # Interval of seconds in which the updates watchdog will kick in + UPDATES_WATCHDOG_INTERVAL = 5 * 60 + mimetypes = MimeTypes() mimetypes.readfp(StringIO(mime_types)) @@ -273,6 +277,13 @@ class Client(Methods): self.message_cache = Cache(10000) + # Sometimes, for some reason, the server will stop sending updates and will only respond to pings. + # This watchdog will invoke updates.GetState in order to wake up the server and enable it sending updates again + # after some idle time has been detected. + self.updates_watchdog_task = None + self.updates_watchdog_event = asyncio.Event() + self.last_update_time = datetime.now() + self.loop = asyncio.get_event_loop() def __enter__(self): @@ -293,6 +304,18 @@ class Client(Methods): except ConnectionError: pass + async def updates_watchdog(self): + while True: + try: + await asyncio.wait_for(self.updates_watchdog_event.wait(), self.UPDATES_WATCHDOG_INTERVAL) + except asyncio.TimeoutError: + pass + else: + break + + if datetime.now() - self.last_update_time > timedelta(seconds=self.UPDATES_WATCHDOG_INTERVAL): + await self.invoke(raw.functions.updates.GetState()) + async def authorize(self) -> User: if self.bot_token: return await self.sign_in_bot(self.bot_token) @@ -485,6 +508,8 @@ class Client(Methods): return is_min async def handle_updates(self, updates): + self.last_update_time = datetime.now() + if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)): is_min = any(( await self.fetch_peers(updates.users), diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 051d3c52..69cbb813 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -48,7 +48,7 @@ class Connection: await self.protocol.connect(self.address) except OSError as e: log.warning("Unable to connect due to network issues: %s", e) - self.protocol.close() + await self.protocol.close() await asyncio.sleep(1) else: log.info("Connected! %s DC%s%s - IPv%s", @@ -59,17 +59,14 @@ class Connection: break else: log.warning("Connection failed! Trying again...") - raise TimeoutError + raise ConnectionError - def close(self): - self.protocol.close() + async def close(self): + await self.protocol.close() log.info("Disconnected") async def send(self, data: bytes): - try: - await self.protocol.send(data) - except Exception as e: - raise OSError(e) + await self.protocol.send(data) async def recv(self) -> Optional[bytes]: return await self.protocol.recv() diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index beb2e58a..6aff86af 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -20,9 +20,6 @@ import asyncio import ipaddress import logging import socket -import time -from concurrent.futures import ThreadPoolExecutor - import socks log = logging.getLogger(__name__) @@ -34,8 +31,8 @@ class TCP: def __init__(self, ipv6: bool, proxy: dict): self.socket = None - self.reader = None # type: asyncio.StreamReader - self.writer = None # type: asyncio.StreamWriter + self.reader = None + self.writer = None self.lock = asyncio.Lock() self.loop = asyncio.get_event_loop() @@ -63,39 +60,37 @@ class TCP: log.info("Using proxy %s", hostname) else: - self.socket = socks.socksocket( + self.socket = socket.socket( socket.AF_INET6 if ipv6 else socket.AF_INET ) - self.socket.settimeout(TCP.TIMEOUT) + self.socket.setblocking(False) async def connect(self, address: tuple): - # The socket used by the whole logic is blocking and thus it blocks when connecting. - # Offload the task to a thread executor to avoid blocking the main event loop. - with ThreadPoolExecutor(1) as executor: - await self.loop.run_in_executor(executor, self.socket.connect, address) + try: + await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT) + except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 + raise TimeoutError("Connection timed out") self.reader, self.writer = await asyncio.open_connection(sock=self.socket) - def close(self): + async def close(self): try: - self.writer.close() - except AttributeError: - try: - self.socket.shutdown(socket.SHUT_RDWR) - except OSError: - pass - finally: - # A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout. - # This is a workaround that seems to fix the occasional delayed stop of a client. - time.sleep(0.001) - self.socket.close() + if self.writer is not None: + self.writer.close() + await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) + except Exception as e: + log.warning("Close exception: %s %s", type(e).__name__, e) async def send(self, data: bytes): async with self.lock: - self.writer.write(data) - await self.writer.drain() + try: + if self.writer is not None: + self.writer.write(data) + await self.writer.drain() + except Exception as e: + log.warning("Send exception: %s %s", type(e).__name__, e) async def recv(self, length: int = 0): data = b"" diff --git a/pyrogram/methods/auth/initialize.py b/pyrogram/methods/auth/initialize.py index 1e7915e0..7188b668 100644 --- a/pyrogram/methods/auth/initialize.py +++ b/pyrogram/methods/auth/initialize.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import logging import pyrogram @@ -46,4 +47,6 @@ class Initialize: await self.dispatcher.start() + self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog()) + self.is_initialized = True diff --git a/pyrogram/methods/auth/terminate.py b/pyrogram/methods/auth/terminate.py index 5ecb6758..70cfc80e 100644 --- a/pyrogram/methods/auth/terminate.py +++ b/pyrogram/methods/auth/terminate.py @@ -51,4 +51,11 @@ class Terminate: self.media_sessions.clear() + self.updates_watchdog_event.set() + + if self.updates_watchdog_task is not None: + await self.updates_watchdog_task + + self.updates_watchdog_event.clear() + self.is_initialized = False diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index d51e18f8..c5d9cd9a 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -278,4 +278,4 @@ class Auth: else: return auth_key finally: - self.connection.close() + await self.connection.close() diff --git a/pyrogram/session/internals/msg_id.py b/pyrogram/session/internals/msg_id.py index 58e3087c..da2e264f 100644 --- a/pyrogram/session/internals/msg_id.py +++ b/pyrogram/session/internals/msg_id.py @@ -27,9 +27,9 @@ class MsgId: offset = 0 def __new__(cls) -> int: - now = time.time() + now = int(time.time()) cls.offset = (cls.offset + 4) if now == cls.last_time else 0 - msg_id = int(now * 2 ** 32) + cls.offset + msg_id = (now * 2 ** 32) + cls.offset cls.last_time = now return msg_id diff --git a/pyrogram/session/internals/seq_no.py b/pyrogram/session/internals/seq_no.py index 0abc4a2f..79501d98 100644 --- a/pyrogram/session/internals/seq_no.py +++ b/pyrogram/session/internals/seq_no.py @@ -16,19 +16,15 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . -from threading import Lock - class SeqNo: def __init__(self): self.content_related_messages_sent = 0 - self.lock = Lock() def __call__(self, is_content_related: bool) -> int: - with self.lock: - seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0) + seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0) - if is_content_related: - self.content_related_messages_sent += 1 + if is_content_related: + self.content_related_messages_sent += 1 - return seq_no + return seq_no diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 5135af69..df7ae6c4 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -44,11 +44,11 @@ class Result: class Session: - START_TIMEOUT = 1 + START_TIMEOUT = 5 WAIT_TIMEOUT = 15 SLEEP_THRESHOLD = 10 - MAX_RETRIES = 5 - ACKS_THRESHOLD = 8 + MAX_RETRIES = 10 + ACKS_THRESHOLD = 10 PING_INTERVAL = 5 def __init__( @@ -156,14 +156,11 @@ class Session: self.ping_task_event.clear() - self.connection.close() + await self.connection.close() if self.recv_task: await self.recv_task - for i in self.results.values(): - i.event.set() - if not self.is_media and callable(self.client.disconnect_handler): try: await self.client.disconnect_handler(self.client) @@ -189,6 +186,7 @@ class Session: ) except SecurityCheckMismatch as e: log.warning("Discarding packet: %s", e) + await self.connection.close() return messages = ( @@ -284,9 +282,6 @@ class Session: message = self.msg_factory(data) msg_id = message.msg_id - if wait_response: - self.results[msg_id] = Result() - log.debug("Sent: %s", message) payload = await self.loop.run_in_executor( @@ -299,34 +294,35 @@ class Session: self.auth_key_id ) - try: - await self.connection.send(payload) - except OSError as e: - self.results.pop(msg_id, None) - raise e + await self.connection.send(payload) if wait_response: + self.results[msg_id] = Result() + try: await asyncio.wait_for(self.results[msg_id].event.wait(), timeout) except asyncio.TimeoutError: pass - finally: - result = self.results.pop(msg_id).value + + result = self.results.pop(msg_id).value if result is None: raise TimeoutError("Request timed out") - elif isinstance(result, raw.types.RpcError): + + if isinstance(result, raw.types.RpcError): if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)): data = data.query RPCError.raise_it(result, type(data)) - elif isinstance(result, raw.types.BadMsgNotification): + + if isinstance(result, raw.types.BadMsgNotification): raise BadMsgNotification(result.error_code) - elif isinstance(result, raw.types.BadServerSalt): + + if isinstance(result, raw.types.BadServerSalt): self.salt = result.new_server_salt return await self.send(data, wait_response, timeout) - else: - return result + + return result async def invoke( self, diff --git a/pyrogram/storage/file_storage.py b/pyrogram/storage/file_storage.py index 986787cd..aebe9176 100644 --- a/pyrogram/storage/file_storage.py +++ b/pyrogram/storage/file_storage.py @@ -38,13 +38,13 @@ class FileStorage(SQLiteStorage): version = self.version() if version == 1: - with self.lock, self.conn: + with self.conn: self.conn.execute("DELETE FROM peers") version += 1 if version == 2: - with self.lock, self.conn: + with self.conn: self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER") version += 1 @@ -63,10 +63,7 @@ class FileStorage(SQLiteStorage): self.update() with self.conn: - try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum - self.conn.execute("VACUUM") - except sqlite3.OperationalError: - pass + self.conn.execute("VACUUM") async def delete(self): os.remove(self.database) diff --git a/pyrogram/storage/sqlite_storage.py b/pyrogram/storage/sqlite_storage.py index 15e5ddc0..e28b9b74 100644 --- a/pyrogram/storage/sqlite_storage.py +++ b/pyrogram/storage/sqlite_storage.py @@ -19,7 +19,6 @@ import inspect import sqlite3 import time -from threading import Lock from typing import List, Tuple, Any from pyrogram import raw @@ -98,10 +97,9 @@ class SQLiteStorage(Storage): super().__init__(name) self.conn = None # type: sqlite3.Connection - self.lock = Lock() def create(self): - with self.lock, self.conn: + with self.conn: self.conn.executescript(SCHEMA) self.conn.execute( @@ -119,24 +117,20 @@ class SQLiteStorage(Storage): async def save(self): await self.date(int(time.time())) - - with self.lock: - self.conn.commit() + self.conn.commit() async def close(self): - with self.lock: - self.conn.close() + self.conn.close() async def delete(self): raise NotImplementedError async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): - with self.lock: - self.conn.executemany( - "REPLACE INTO peers (id, access_hash, type, username, phone_number)" - "VALUES (?, ?, ?, ?, ?)", - peers - ) + self.conn.executemany( + "REPLACE INTO peers (id, access_hash, type, username, phone_number)" + "VALUES (?, ?, ?, ?, ?)", + peers + ) async def get_peer_by_id(self, peer_id: int): r = self.conn.execute( @@ -185,7 +179,7 @@ class SQLiteStorage(Storage): def _set(self, value: Any): attr = inspect.stack()[2].function - with self.lock, self.conn: + with self.conn: self.conn.execute( f"UPDATE sessions SET {attr} = ?", (value,) @@ -221,7 +215,7 @@ class SQLiteStorage(Storage): "SELECT number FROM version" ).fetchone()[0] else: - with self.lock, self.conn: + with self.conn: self.conn.execute( "UPDATE version SET number = ?", (value,)