mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-16 04:35:24 +00:00
Add skip_updates parameter to Client class
This commit is contained in:
parent
befab2f1b5
commit
c16c83abc3
@ -163,6 +163,10 @@ class Client(Methods):
|
||||
Useful for batch programs that don't need to deal with updates.
|
||||
Defaults to False (updates enabled and received).
|
||||
|
||||
skip_updates (``bool``, *optional*):
|
||||
Pass True to skip pending updates that arrived while the client was offline.
|
||||
Defaults to True.
|
||||
|
||||
takeout (``bool``, *optional*):
|
||||
Pass True to let the client use a takeout session instead of a normal one, implies *no_updates=True*.
|
||||
Useful for exporting Telegram data. Methods invoked inside a takeout session (such as get_chat_history,
|
||||
@ -242,12 +246,13 @@ class Client(Methods):
|
||||
plugins: dict = None,
|
||||
parse_mode: "enums.ParseMode" = enums.ParseMode.DEFAULT,
|
||||
no_updates: bool = None,
|
||||
skip_updates: bool = True,
|
||||
takeout: bool = None,
|
||||
sleep_threshold: int = Session.SLEEP_THRESHOLD,
|
||||
hide_password: bool = False,
|
||||
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
|
||||
storage_engine: Storage = None,
|
||||
init_connection_params: "raw.base.JSONValue" = None,
|
||||
init_connection_params: "raw.base.JSONValue" = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -274,6 +279,7 @@ class Client(Methods):
|
||||
self.plugins = plugins
|
||||
self.parse_mode = parse_mode
|
||||
self.no_updates = no_updates
|
||||
self.skip_updates = skip_updates
|
||||
self.takeout = takeout
|
||||
self.sleep_threshold = sleep_threshold
|
||||
self.hide_password = hide_password
|
||||
@ -583,6 +589,17 @@ class Client(Methods):
|
||||
pts = getattr(update, "pts", None)
|
||||
pts_count = getattr(update, "pts_count", None)
|
||||
|
||||
if pts:
|
||||
await self.storage.update_state(
|
||||
(
|
||||
utils.get_channel_id(channel_id) if channel_id else self.me.id,
|
||||
pts,
|
||||
None,
|
||||
updates.date,
|
||||
None
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(update, raw.types.UpdateChannelTooLong):
|
||||
log.info(update)
|
||||
|
||||
@ -613,6 +630,16 @@ class Client(Methods):
|
||||
|
||||
self.dispatcher.updates_queue.put_nowait((update, users, chats))
|
||||
elif isinstance(updates, (raw.types.UpdateShortMessage, raw.types.UpdateShortChatMessage)):
|
||||
await self.storage.update_state(
|
||||
(
|
||||
self.me.id,
|
||||
updates.pts,
|
||||
None,
|
||||
updates.date,
|
||||
None
|
||||
)
|
||||
)
|
||||
|
||||
diff = await self.invoke(
|
||||
raw.functions.updates.GetDifference(
|
||||
pts=updates.pts - updates.pts_count,
|
||||
|
@ -23,6 +23,7 @@ from collections import OrderedDict
|
||||
|
||||
import pyrogram
|
||||
from pyrogram import utils
|
||||
from pyrogram import raw
|
||||
from pyrogram.handlers import (
|
||||
CallbackQueryHandler, MessageHandler, EditedMessageHandler, DeletedMessagesHandler,
|
||||
UserStatusHandler, RawUpdateHandler, InlineQueryHandler, PollHandler,
|
||||
@ -166,6 +167,87 @@ class Dispatcher:
|
||||
|
||||
log.info("Started %s HandlerTasks", self.client.workers)
|
||||
|
||||
if not self.client.skip_updates:
|
||||
states = await self.client.storage.update_state()
|
||||
|
||||
if not states:
|
||||
log.info("No states found, skipping recovery.")
|
||||
return
|
||||
|
||||
message_updates_counter = 0
|
||||
other_updates_counter = 0
|
||||
|
||||
for state in states:
|
||||
id, local_pts, _, local_date, _ = state
|
||||
|
||||
prev_pts = 0
|
||||
|
||||
while True:
|
||||
diff = await self.client.invoke(
|
||||
raw.functions.updates.GetDifference(
|
||||
pts=local_pts,
|
||||
date=local_date,
|
||||
qts=0
|
||||
) if id == self.client.me.id else
|
||||
raw.functions.updates.GetChannelDifference(
|
||||
channel=await self.client.resolve_peer(id),
|
||||
filter=raw.types.ChannelMessagesFilterEmpty(),
|
||||
pts=local_pts,
|
||||
limit=10000
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(diff, (raw.types.updates.DifferenceEmpty, raw.types.updates.ChannelDifferenceEmpty)):
|
||||
break
|
||||
elif isinstance(diff, (raw.types.updates.DifferenceTooLong, raw.types.updates.ChannelDifferenceTooLong)):
|
||||
break
|
||||
elif isinstance(diff, raw.types.updates.ChannelDifference):
|
||||
local_pts = diff.pts
|
||||
elif isinstance(diff, raw.types.updates.Difference):
|
||||
local_pts = diff.state.pts
|
||||
elif isinstance(diff, raw.types.updates.DifferenceSlice):
|
||||
local_pts = diff.intermediate_state.pts
|
||||
local_date = diff.intermediate_state.date
|
||||
|
||||
if prev_pts == local_pts:
|
||||
break
|
||||
|
||||
prev_pts = local_pts
|
||||
|
||||
users = {i.id: i for i in diff.users}
|
||||
chats = {i.id: i for i in diff.chats}
|
||||
|
||||
for message in diff.new_messages:
|
||||
message_updates_counter += 1
|
||||
self.updates_queue.put_nowait(
|
||||
(
|
||||
raw.types.UpdateNewMessage(
|
||||
message=message,
|
||||
pts=local_pts,
|
||||
pts_count=-1
|
||||
) if id == self.client.me.id else
|
||||
raw.types.UpdateNewChannelMessage(
|
||||
message=message,
|
||||
pts=local_pts,
|
||||
pts_count=-1
|
||||
),
|
||||
users,
|
||||
chats
|
||||
)
|
||||
)
|
||||
|
||||
for update in diff.other_updates:
|
||||
other_updates_counter += 1
|
||||
self.updates_queue.put_nowait(
|
||||
(update, users, chats)
|
||||
)
|
||||
|
||||
if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
|
||||
break
|
||||
|
||||
await self.client.storage.update_state(None)
|
||||
log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
|
||||
|
||||
async def stop(self):
|
||||
if not self.client.no_updates:
|
||||
for i in range(self.client.workers):
|
||||
|
@ -25,8 +25,8 @@ from .sqlite_storage import SQLiteStorage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS usernames
|
||||
USERNAMES_SCHEMA = """
|
||||
CREATE TABLE usernames
|
||||
(
|
||||
id INTEGER,
|
||||
username TEXT,
|
||||
@ -36,6 +36,17 @@ CREATE TABLE IF NOT EXISTS usernames
|
||||
CREATE INDEX idx_usernames_username ON usernames (username);
|
||||
"""
|
||||
|
||||
UPDATE_STATE_SCHEMA = """
|
||||
CREATE TABLE update_state
|
||||
(
|
||||
id INTEGER PRIMARY KEY,
|
||||
pts INTEGER,
|
||||
qts INTEGER,
|
||||
date INTEGER,
|
||||
seq INTEGER
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
class FileStorage(SQLiteStorage):
|
||||
FILE_EXTENSION = ".session"
|
||||
@ -62,7 +73,13 @@ class FileStorage(SQLiteStorage):
|
||||
|
||||
if version == 3:
|
||||
with self.conn:
|
||||
self.conn.executescript(SCHEMA)
|
||||
self.conn.executescript(USERNAMES_SCHEMA)
|
||||
|
||||
version += 1
|
||||
|
||||
if version == 4:
|
||||
with self.conn:
|
||||
self.conn.executescript(UPDATE_STATE_SCHEMA)
|
||||
|
||||
version += 1
|
||||
|
||||
|
@ -54,6 +54,15 @@ CREATE TABLE usernames
|
||||
FOREIGN KEY (id) REFERENCES peers(id)
|
||||
);
|
||||
|
||||
CREATE TABLE update_state
|
||||
(
|
||||
id INTEGER PRIMARY KEY,
|
||||
pts INTEGER,
|
||||
qts INTEGER,
|
||||
date INTEGER,
|
||||
seq INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE version
|
||||
(
|
||||
number INTEGER PRIMARY KEY
|
||||
@ -96,7 +105,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
|
||||
|
||||
|
||||
class SQLiteStorage(Storage):
|
||||
VERSION = 4
|
||||
VERSION = 5
|
||||
USERNAME_TTL = 8 * 60 * 60
|
||||
|
||||
def __init__(self, name: str):
|
||||
@ -151,6 +160,24 @@ class SQLiteStorage(Storage):
|
||||
[(id, username) for username in usernames] if usernames else [(id, None)]
|
||||
)
|
||||
|
||||
async def update_state(self, value: Tuple[int, int, int, int, int] = object):
|
||||
if value == object:
|
||||
return self.conn.execute(
|
||||
"SELECT id, pts, qts, date, seq FROM update_state"
|
||||
).fetchall()
|
||||
else:
|
||||
with self.conn:
|
||||
if value is None:
|
||||
self.conn.execute(
|
||||
"DELETE FROM update_state"
|
||||
)
|
||||
else:
|
||||
self.conn.execute(
|
||||
"REPLACE INTO update_state (id, pts, qts, date, seq)"
|
||||
"VALUES (?, ?, ?, ?, ?)",
|
||||
value
|
||||
)
|
||||
|
||||
async def get_peer_by_id(self, peer_id: int):
|
||||
r = self.conn.execute(
|
||||
"SELECT id, access_hash, type FROM peers WHERE id = ?",
|
||||
|
@ -77,6 +77,21 @@ class Storage(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def update_state(self, update_state: Tuple[int, int, int, int, int] = object):
|
||||
"""Get or set the update state of the current session.
|
||||
|
||||
Parameters:
|
||||
update_state (``Tuple[int, int, int, int, int]``): A tuple containing the update state to set.
|
||||
Tuple must contain the following information:
|
||||
- ``int``: The id of the entity.
|
||||
- ``int``: The pts.
|
||||
- ``int``: The qts.
|
||||
- ``int``: The date.
|
||||
- ``int``: The seq.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_peer_by_id(self, peer_id: int):
|
||||
"""Retrieve a peer by its ID.
|
||||
|
Loading…
Reference in New Issue
Block a user