From 9bf742abc0bb6342749ce108fa49dd9c2bedb646 Mon Sep 17 00:00:00 2001
From: Dan <14043624+delivrance@users.noreply.github.com>
Date: Tue, 27 Dec 2022 13:40:42 +0100
Subject: [PATCH] Introduce back some previously reverted changes
---
pyrogram/client.py | 25 +++++++++++++
pyrogram/connection/connection.py | 13 +++----
pyrogram/connection/transport/tcp/tcp.py | 45 +++++++++++-------------
pyrogram/methods/auth/initialize.py | 3 ++
pyrogram/methods/auth/terminate.py | 7 ++++
pyrogram/session/auth.py | 2 +-
pyrogram/session/internals/msg_id.py | 4 +--
pyrogram/session/internals/seq_no.py | 12 +++----
pyrogram/session/session.py | 40 ++++++++++-----------
pyrogram/storage/file_storage.py | 9 ++---
pyrogram/storage/sqlite_storage.py | 26 ++++++--------
11 files changed, 98 insertions(+), 88 deletions(-)
diff --git a/pyrogram/client.py b/pyrogram/client.py
index 63e4b472..36ab4e4c 100644
--- a/pyrogram/client.py
+++ b/pyrogram/client.py
@@ -26,6 +26,7 @@ import re
import shutil
import sys
from concurrent.futures.thread import ThreadPoolExecutor
+from datetime import datetime, timedelta
from hashlib import sha256
from importlib import import_module
from io import StringIO, BytesIO
@@ -185,6 +186,9 @@ class Client(Methods):
WORKERS = min(32, (os.cpu_count() or 0) + 4) # os.cpu_count() can be None
WORKDIR = PARENT_DIR
+ # Interval of seconds in which the updates watchdog will kick in
+ UPDATES_WATCHDOG_INTERVAL = 5 * 60
+
mimetypes = MimeTypes()
mimetypes.readfp(StringIO(mime_types))
@@ -273,6 +277,13 @@ class Client(Methods):
self.message_cache = Cache(10000)
+ # Sometimes, for some reason, the server will stop sending updates and will only respond to pings.
+ # This watchdog will invoke updates.GetState in order to wake up the server and enable it sending updates again
+ # after some idle time has been detected.
+ self.updates_watchdog_task = None
+ self.updates_watchdog_event = asyncio.Event()
+ self.last_update_time = datetime.now()
+
self.loop = asyncio.get_event_loop()
def __enter__(self):
@@ -293,6 +304,18 @@ class Client(Methods):
except ConnectionError:
pass
+ async def updates_watchdog(self):
+ while True:
+ try:
+ await asyncio.wait_for(self.updates_watchdog_event.wait(), self.UPDATES_WATCHDOG_INTERVAL)
+ except asyncio.TimeoutError:
+ pass
+ else:
+ break
+
+ if datetime.now() - self.last_update_time > timedelta(seconds=self.UPDATES_WATCHDOG_INTERVAL):
+ await self.invoke(raw.functions.updates.GetState())
+
async def authorize(self) -> User:
if self.bot_token:
return await self.sign_in_bot(self.bot_token)
@@ -485,6 +508,8 @@ class Client(Methods):
return is_min
async def handle_updates(self, updates):
+ self.last_update_time = datetime.now()
+
if isinstance(updates, (raw.types.Updates, raw.types.UpdatesCombined)):
is_min = any((
await self.fetch_peers(updates.users),
diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py
index 051d3c52..69cbb813 100644
--- a/pyrogram/connection/connection.py
+++ b/pyrogram/connection/connection.py
@@ -48,7 +48,7 @@ class Connection:
await self.protocol.connect(self.address)
except OSError as e:
log.warning("Unable to connect due to network issues: %s", e)
- self.protocol.close()
+ await self.protocol.close()
await asyncio.sleep(1)
else:
log.info("Connected! %s DC%s%s - IPv%s",
@@ -59,17 +59,14 @@ class Connection:
break
else:
log.warning("Connection failed! Trying again...")
- raise TimeoutError
+ raise ConnectionError
- def close(self):
- self.protocol.close()
+ async def close(self):
+ await self.protocol.close()
log.info("Disconnected")
async def send(self, data: bytes):
- try:
- await self.protocol.send(data)
- except Exception as e:
- raise OSError(e)
+ await self.protocol.send(data)
async def recv(self) -> Optional[bytes]:
return await self.protocol.recv()
diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py
index beb2e58a..6aff86af 100644
--- a/pyrogram/connection/transport/tcp/tcp.py
+++ b/pyrogram/connection/transport/tcp/tcp.py
@@ -20,9 +20,6 @@ import asyncio
import ipaddress
import logging
import socket
-import time
-from concurrent.futures import ThreadPoolExecutor
-
import socks
log = logging.getLogger(__name__)
@@ -34,8 +31,8 @@ class TCP:
def __init__(self, ipv6: bool, proxy: dict):
self.socket = None
- self.reader = None # type: asyncio.StreamReader
- self.writer = None # type: asyncio.StreamWriter
+ self.reader = None
+ self.writer = None
self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()
@@ -63,39 +60,37 @@ class TCP:
log.info("Using proxy %s", hostname)
else:
- self.socket = socks.socksocket(
+ self.socket = socket.socket(
socket.AF_INET6 if ipv6
else socket.AF_INET
)
- self.socket.settimeout(TCP.TIMEOUT)
+ self.socket.setblocking(False)
async def connect(self, address: tuple):
- # The socket used by the whole logic is blocking and thus it blocks when connecting.
- # Offload the task to a thread executor to avoid blocking the main event loop.
- with ThreadPoolExecutor(1) as executor:
- await self.loop.run_in_executor(executor, self.socket.connect, address)
+ try:
+ await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT)
+ except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
+ raise TimeoutError("Connection timed out")
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
- def close(self):
+ async def close(self):
try:
- self.writer.close()
- except AttributeError:
- try:
- self.socket.shutdown(socket.SHUT_RDWR)
- except OSError:
- pass
- finally:
- # A tiny sleep placed here helps avoiding .recv(n) hanging until the timeout.
- # This is a workaround that seems to fix the occasional delayed stop of a client.
- time.sleep(0.001)
- self.socket.close()
+ if self.writer is not None:
+ self.writer.close()
+ await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
+ except Exception as e:
+ log.warning("Close exception: %s %s", type(e).__name__, e)
async def send(self, data: bytes):
async with self.lock:
- self.writer.write(data)
- await self.writer.drain()
+ try:
+ if self.writer is not None:
+ self.writer.write(data)
+ await self.writer.drain()
+ except Exception as e:
+ log.warning("Send exception: %s %s", type(e).__name__, e)
async def recv(self, length: int = 0):
data = b""
diff --git a/pyrogram/methods/auth/initialize.py b/pyrogram/methods/auth/initialize.py
index 1e7915e0..7188b668 100644
--- a/pyrogram/methods/auth/initialize.py
+++ b/pyrogram/methods/auth/initialize.py
@@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
+import asyncio
import logging
import pyrogram
@@ -46,4 +47,6 @@ class Initialize:
await self.dispatcher.start()
+ self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog())
+
self.is_initialized = True
diff --git a/pyrogram/methods/auth/terminate.py b/pyrogram/methods/auth/terminate.py
index 5ecb6758..70cfc80e 100644
--- a/pyrogram/methods/auth/terminate.py
+++ b/pyrogram/methods/auth/terminate.py
@@ -51,4 +51,11 @@ class Terminate:
self.media_sessions.clear()
+ self.updates_watchdog_event.set()
+
+ if self.updates_watchdog_task is not None:
+ await self.updates_watchdog_task
+
+ self.updates_watchdog_event.clear()
+
self.is_initialized = False
diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py
index d51e18f8..c5d9cd9a 100644
--- a/pyrogram/session/auth.py
+++ b/pyrogram/session/auth.py
@@ -278,4 +278,4 @@ class Auth:
else:
return auth_key
finally:
- self.connection.close()
+ await self.connection.close()
diff --git a/pyrogram/session/internals/msg_id.py b/pyrogram/session/internals/msg_id.py
index 58e3087c..da2e264f 100644
--- a/pyrogram/session/internals/msg_id.py
+++ b/pyrogram/session/internals/msg_id.py
@@ -27,9 +27,9 @@ class MsgId:
offset = 0
def __new__(cls) -> int:
- now = time.time()
+ now = int(time.time())
cls.offset = (cls.offset + 4) if now == cls.last_time else 0
- msg_id = int(now * 2 ** 32) + cls.offset
+ msg_id = (now * 2 ** 32) + cls.offset
cls.last_time = now
return msg_id
diff --git a/pyrogram/session/internals/seq_no.py b/pyrogram/session/internals/seq_no.py
index 0abc4a2f..79501d98 100644
--- a/pyrogram/session/internals/seq_no.py
+++ b/pyrogram/session/internals/seq_no.py
@@ -16,19 +16,15 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
-from threading import Lock
-
class SeqNo:
def __init__(self):
self.content_related_messages_sent = 0
- self.lock = Lock()
def __call__(self, is_content_related: bool) -> int:
- with self.lock:
- seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0)
+ seq_no = (self.content_related_messages_sent * 2) + (1 if is_content_related else 0)
- if is_content_related:
- self.content_related_messages_sent += 1
+ if is_content_related:
+ self.content_related_messages_sent += 1
- return seq_no
+ return seq_no
diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py
index 5135af69..df7ae6c4 100644
--- a/pyrogram/session/session.py
+++ b/pyrogram/session/session.py
@@ -44,11 +44,11 @@ class Result:
class Session:
- START_TIMEOUT = 1
+ START_TIMEOUT = 5
WAIT_TIMEOUT = 15
SLEEP_THRESHOLD = 10
- MAX_RETRIES = 5
- ACKS_THRESHOLD = 8
+ MAX_RETRIES = 10
+ ACKS_THRESHOLD = 10
PING_INTERVAL = 5
def __init__(
@@ -156,14 +156,11 @@ class Session:
self.ping_task_event.clear()
- self.connection.close()
+ await self.connection.close()
if self.recv_task:
await self.recv_task
- for i in self.results.values():
- i.event.set()
-
if not self.is_media and callable(self.client.disconnect_handler):
try:
await self.client.disconnect_handler(self.client)
@@ -189,6 +186,7 @@ class Session:
)
except SecurityCheckMismatch as e:
log.warning("Discarding packet: %s", e)
+ await self.connection.close()
return
messages = (
@@ -284,9 +282,6 @@ class Session:
message = self.msg_factory(data)
msg_id = message.msg_id
- if wait_response:
- self.results[msg_id] = Result()
-
log.debug("Sent: %s", message)
payload = await self.loop.run_in_executor(
@@ -299,34 +294,35 @@ class Session:
self.auth_key_id
)
- try:
- await self.connection.send(payload)
- except OSError as e:
- self.results.pop(msg_id, None)
- raise e
+ await self.connection.send(payload)
if wait_response:
+ self.results[msg_id] = Result()
+
try:
await asyncio.wait_for(self.results[msg_id].event.wait(), timeout)
except asyncio.TimeoutError:
pass
- finally:
- result = self.results.pop(msg_id).value
+
+ result = self.results.pop(msg_id).value
if result is None:
raise TimeoutError("Request timed out")
- elif isinstance(result, raw.types.RpcError):
+
+ if isinstance(result, raw.types.RpcError):
if isinstance(data, (raw.functions.InvokeWithoutUpdates, raw.functions.InvokeWithTakeout)):
data = data.query
RPCError.raise_it(result, type(data))
- elif isinstance(result, raw.types.BadMsgNotification):
+
+ if isinstance(result, raw.types.BadMsgNotification):
raise BadMsgNotification(result.error_code)
- elif isinstance(result, raw.types.BadServerSalt):
+
+ if isinstance(result, raw.types.BadServerSalt):
self.salt = result.new_server_salt
return await self.send(data, wait_response, timeout)
- else:
- return result
+
+ return result
async def invoke(
self,
diff --git a/pyrogram/storage/file_storage.py b/pyrogram/storage/file_storage.py
index 986787cd..aebe9176 100644
--- a/pyrogram/storage/file_storage.py
+++ b/pyrogram/storage/file_storage.py
@@ -38,13 +38,13 @@ class FileStorage(SQLiteStorage):
version = self.version()
if version == 1:
- with self.lock, self.conn:
+ with self.conn:
self.conn.execute("DELETE FROM peers")
version += 1
if version == 2:
- with self.lock, self.conn:
+ with self.conn:
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
version += 1
@@ -63,10 +63,7 @@ class FileStorage(SQLiteStorage):
self.update()
with self.conn:
- try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
- self.conn.execute("VACUUM")
- except sqlite3.OperationalError:
- pass
+ self.conn.execute("VACUUM")
async def delete(self):
os.remove(self.database)
diff --git a/pyrogram/storage/sqlite_storage.py b/pyrogram/storage/sqlite_storage.py
index 15e5ddc0..e28b9b74 100644
--- a/pyrogram/storage/sqlite_storage.py
+++ b/pyrogram/storage/sqlite_storage.py
@@ -19,7 +19,6 @@
import inspect
import sqlite3
import time
-from threading import Lock
from typing import List, Tuple, Any
from pyrogram import raw
@@ -98,10 +97,9 @@ class SQLiteStorage(Storage):
super().__init__(name)
self.conn = None # type: sqlite3.Connection
- self.lock = Lock()
def create(self):
- with self.lock, self.conn:
+ with self.conn:
self.conn.executescript(SCHEMA)
self.conn.execute(
@@ -119,24 +117,20 @@ class SQLiteStorage(Storage):
async def save(self):
await self.date(int(time.time()))
-
- with self.lock:
- self.conn.commit()
+ self.conn.commit()
async def close(self):
- with self.lock:
- self.conn.close()
+ self.conn.close()
async def delete(self):
raise NotImplementedError
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
- with self.lock:
- self.conn.executemany(
- "REPLACE INTO peers (id, access_hash, type, username, phone_number)"
- "VALUES (?, ?, ?, ?, ?)",
- peers
- )
+ self.conn.executemany(
+ "REPLACE INTO peers (id, access_hash, type, username, phone_number)"
+ "VALUES (?, ?, ?, ?, ?)",
+ peers
+ )
async def get_peer_by_id(self, peer_id: int):
r = self.conn.execute(
@@ -185,7 +179,7 @@ class SQLiteStorage(Storage):
def _set(self, value: Any):
attr = inspect.stack()[2].function
- with self.lock, self.conn:
+ with self.conn:
self.conn.execute(
f"UPDATE sessions SET {attr} = ?",
(value,)
@@ -221,7 +215,7 @@ class SQLiteStorage(Storage):
"SELECT number FROM version"
).fetchone()[0]
else:
- with self.lock, self.conn:
+ with self.conn:
self.conn.execute(
"UPDATE version SET number = ?",
(value,)