Extract recovering gaps into a separate method

This commit is contained in:
KurimuzonAkuma 2024-08-28 21:07:22 +03:00
parent 18ff0bae2f
commit 5c927d87a3
2 changed files with 92 additions and 90 deletions

View File

@ -32,7 +32,7 @@ from importlib import import_module
from io import StringIO, BytesIO
from mimetypes import MimeTypes
from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator, Type
from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
import pyrogram
from pyrogram import __version__, __license__
@ -45,7 +45,8 @@ from pyrogram.errors import (
SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate,
BadRequest, AuthBytesInvalid,
FloodWait, FloodPremiumWait
FloodWait, FloodPremiumWait,
ChannelInvalid
)
from pyrogram.handlers.handler import Handler
from pyrogram.methods import Methods
@ -645,7 +646,8 @@ class Client(Methods):
)]
),
pts=pts - pts_count,
limit=pts
limit=pts,
force=False
)
)
except ChannelPrivate:
@ -694,6 +696,92 @@ class Client(Methods):
elif isinstance(updates, raw.types.UpdatesTooLong):
log.info(updates)
async def recover_gaps(self) -> Tuple[int, int]:
states = await self.storage.update_state()
message_updates_counter = 0
other_updates_counter = 0
if not states:
log.info("No states found, skipping recovery.")
return (message_updates_counter, other_updates_counter)
for state in states:
id, local_pts, _, local_date, _ = state
prev_pts = 0
while True:
try:
diff = await self.invoke(
raw.functions.updates.GetChannelDifference(
channel=await self.resolve_peer(id),
filter=raw.types.ChannelMessagesFilterEmpty(),
pts=local_pts,
limit=10000,
force=False
) if id < 0 else
raw.functions.updates.GetDifference(
pts=local_pts,
date=local_date,
qts=0
)
)
except (ChannelPrivate, ChannelInvalid):
break
if isinstance(diff, raw.types.updates.DifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.DifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.Difference):
local_pts = diff.state.pts
elif isinstance(diff, raw.types.updates.DifferenceSlice):
local_pts = diff.intermediate_state.pts
local_date = diff.intermediate_state.date
if prev_pts == local_pts:
break
prev_pts = local_pts
elif isinstance(diff, raw.types.updates.ChannelDifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.ChannelDifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.ChannelDifference):
local_pts = diff.pts
users = {i.id: i for i in diff.users}
chats = {i.id: i for i in diff.chats}
for message in diff.new_messages:
message_updates_counter += 1
self.dispatcher.updates_queue.put_nowait(
(
raw.types.UpdateNewMessage(
message=message,
pts=local_pts,
pts_count=-1
),
users,
chats
)
)
for update in diff.other_updates:
other_updates_counter += 1
self.dispatcher.updates_queue.put_nowait(
(update, users, chats)
)
if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
break
await self.storage.update_state(id)
log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
return (message_updates_counter, other_updates_counter)
async def load_session(self):
await self.storage.open()

View File

@ -180,93 +180,7 @@ class Dispatcher:
log.info("Started %s HandlerTasks", self.client.workers)
if not self.client.skip_updates:
states = await self.client.storage.update_state()
if not states:
log.info("No states found, skipping recovery.")
return
message_updates_counter = 0
other_updates_counter = 0
for state in states:
id, local_pts, _, local_date, _ = state
prev_pts = 0
while True:
try:
diff = await self.client.invoke(
raw.functions.updates.GetChannelDifference(
channel=await self.client.resolve_peer(id),
filter=raw.types.ChannelMessagesFilterEmpty(),
pts=local_pts,
limit=10000
) if id < 0 else
raw.functions.updates.GetDifference(
pts=local_pts,
date=local_date,
qts=0
)
)
except (errors.ChannelPrivate, errors.ChannelInvalid):
break
if isinstance(diff, raw.types.updates.DifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.DifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.Difference):
local_pts = diff.state.pts
elif isinstance(diff, raw.types.updates.DifferenceSlice):
local_pts = diff.intermediate_state.pts
local_date = diff.intermediate_state.date
if prev_pts == local_pts:
break
prev_pts = local_pts
elif isinstance(diff, raw.types.updates.ChannelDifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.ChannelDifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.ChannelDifference):
local_pts = diff.pts
users = {i.id: i for i in diff.users}
chats = {i.id: i for i in diff.chats}
for message in diff.new_messages:
message_updates_counter += 1
self.updates_queue.put_nowait(
(
raw.types.UpdateNewMessage(
message=message,
pts=local_pts,
pts_count=-1
) if id == self.client.me.id else
raw.types.UpdateNewChannelMessage(
message=message,
pts=local_pts,
pts_count=-1
),
users,
chats
)
)
for update in diff.other_updates:
other_updates_counter += 1
self.updates_queue.put_nowait(
(update, users, chats)
)
if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
break
await self.client.storage.update_state(id)
log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
await self.client.recover_gaps()
async def stop(self):
if not self.client.no_updates: