diff --git a/pyrogram/client.py b/pyrogram/client.py index d885735a..816bb76a 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -32,7 +32,7 @@ from importlib import import_module from io import StringIO, BytesIO from mimetypes import MimeTypes from pathlib import Path -from typing import Union, List, Optional, Callable, AsyncGenerator, Type +from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple import pyrogram from pyrogram import __version__, __license__ @@ -45,7 +45,8 @@ from pyrogram.errors import ( SessionPasswordNeeded, VolumeLocNotFound, ChannelPrivate, BadRequest, AuthBytesInvalid, - FloodWait, FloodPremiumWait + FloodWait, FloodPremiumWait, + ChannelInvalid ) from pyrogram.handlers.handler import Handler from pyrogram.methods import Methods @@ -645,7 +646,8 @@ class Client(Methods): )] ), pts=pts - pts_count, - limit=pts + limit=pts, + force=False ) ) except ChannelPrivate: @@ -694,6 +696,92 @@ class Client(Methods): elif isinstance(updates, raw.types.UpdatesTooLong): log.info(updates) + async def recover_gaps(self) -> Tuple[int, int]: + states = await self.storage.update_state() + + message_updates_counter = 0 + other_updates_counter = 0 + + if not states: + log.info("No states found, skipping recovery.") + return (message_updates_counter, other_updates_counter) + + for state in states: + id, local_pts, _, local_date, _ = state + + prev_pts = 0 + + while True: + try: + diff = await self.invoke( + raw.functions.updates.GetChannelDifference( + channel=await self.resolve_peer(id), + filter=raw.types.ChannelMessagesFilterEmpty(), + pts=local_pts, + limit=10000, + force=False + ) if id < 0 else + raw.functions.updates.GetDifference( + pts=local_pts, + date=local_date, + qts=0 + ) + ) + except (ChannelPrivate, ChannelInvalid): + break + + if isinstance(diff, raw.types.updates.DifferenceEmpty): + break + elif isinstance(diff, raw.types.updates.DifferenceTooLong): + break + elif isinstance(diff, raw.types.updates.Difference): + local_pts = diff.state.pts + elif isinstance(diff, raw.types.updates.DifferenceSlice): + local_pts = diff.intermediate_state.pts + local_date = diff.intermediate_state.date + + if prev_pts == local_pts: + break + + prev_pts = local_pts + elif isinstance(diff, raw.types.updates.ChannelDifferenceEmpty): + break + elif isinstance(diff, raw.types.updates.ChannelDifferenceTooLong): + break + elif isinstance(diff, raw.types.updates.ChannelDifference): + local_pts = diff.pts + + users = {i.id: i for i in diff.users} + chats = {i.id: i for i in diff.chats} + + for message in diff.new_messages: + message_updates_counter += 1 + self.dispatcher.updates_queue.put_nowait( + ( + raw.types.UpdateNewMessage( + message=message, + pts=local_pts, + pts_count=-1 + ), + users, + chats + ) + ) + + for update in diff.other_updates: + other_updates_counter += 1 + self.dispatcher.updates_queue.put_nowait( + (update, users, chats) + ) + + if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)): + break + + await self.storage.update_state(id) + + log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter) + return (message_updates_counter, other_updates_counter) + async def load_session(self): await self.storage.open() diff --git a/pyrogram/dispatcher.py b/pyrogram/dispatcher.py index 1e6439c7..935ff0fa 100644 --- a/pyrogram/dispatcher.py +++ b/pyrogram/dispatcher.py @@ -180,93 +180,7 @@ class Dispatcher: log.info("Started %s HandlerTasks", self.client.workers) if not self.client.skip_updates: - states = await self.client.storage.update_state() - - if not states: - log.info("No states found, skipping recovery.") - return - - message_updates_counter = 0 - other_updates_counter = 0 - - for state in states: - id, local_pts, _, local_date, _ = state - - prev_pts = 0 - - while True: - try: - diff = await self.client.invoke( - raw.functions.updates.GetChannelDifference( - channel=await self.client.resolve_peer(id), - filter=raw.types.ChannelMessagesFilterEmpty(), - pts=local_pts, - limit=10000 - ) if id < 0 else - raw.functions.updates.GetDifference( - pts=local_pts, - date=local_date, - qts=0 - ) - ) - except (errors.ChannelPrivate, errors.ChannelInvalid): - break - - if isinstance(diff, raw.types.updates.DifferenceEmpty): - break - elif isinstance(diff, raw.types.updates.DifferenceTooLong): - break - elif isinstance(diff, raw.types.updates.Difference): - local_pts = diff.state.pts - elif isinstance(diff, raw.types.updates.DifferenceSlice): - local_pts = diff.intermediate_state.pts - local_date = diff.intermediate_state.date - - if prev_pts == local_pts: - break - - prev_pts = local_pts - elif isinstance(diff, raw.types.updates.ChannelDifferenceEmpty): - break - elif isinstance(diff, raw.types.updates.ChannelDifferenceTooLong): - break - elif isinstance(diff, raw.types.updates.ChannelDifference): - local_pts = diff.pts - - users = {i.id: i for i in diff.users} - chats = {i.id: i for i in diff.chats} - - for message in diff.new_messages: - message_updates_counter += 1 - self.updates_queue.put_nowait( - ( - raw.types.UpdateNewMessage( - message=message, - pts=local_pts, - pts_count=-1 - ) if id == self.client.me.id else - raw.types.UpdateNewChannelMessage( - message=message, - pts=local_pts, - pts_count=-1 - ), - users, - chats - ) - ) - - for update in diff.other_updates: - other_updates_counter += 1 - self.updates_queue.put_nowait( - (update, users, chats) - ) - - if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)): - break - - await self.client.storage.update_state(id) - - log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter) + await self.client.recover_gaps() async def stop(self): if not self.client.no_updates: