Extract recovering gaps into a separate method

This commit is contained in:
KurimuzonAkuma 2024-08-28 21:07:22 +03:00
parent 18ff0bae2f
commit 5c927d87a3
2 changed files with 92 additions and 90 deletions

View File

@ -32,7 +32,7 @@ from importlib import import_module
from io import StringIO, BytesIO from io import StringIO, BytesIO
from mimetypes import MimeTypes from mimetypes import MimeTypes
from pathlib import Path from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator, Type from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
import pyrogram import pyrogram
from pyrogram import __version__, __license__ from pyrogram import __version__, __license__
@ -45,7 +45,8 @@ from pyrogram.errors import (
SessionPasswordNeeded, SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate, VolumeLocNotFound, ChannelPrivate,
BadRequest, AuthBytesInvalid, BadRequest, AuthBytesInvalid,
FloodWait, FloodPremiumWait FloodWait, FloodPremiumWait,
ChannelInvalid
) )
from pyrogram.handlers.handler import Handler from pyrogram.handlers.handler import Handler
from pyrogram.methods import Methods from pyrogram.methods import Methods
@ -645,7 +646,8 @@ class Client(Methods):
)] )]
), ),
pts=pts - pts_count, pts=pts - pts_count,
limit=pts limit=pts,
force=False
) )
) )
except ChannelPrivate: except ChannelPrivate:
@ -694,6 +696,92 @@ class Client(Methods):
elif isinstance(updates, raw.types.UpdatesTooLong): elif isinstance(updates, raw.types.UpdatesTooLong):
log.info(updates) log.info(updates)
async def recover_gaps(self) -> Tuple[int, int]:
states = await self.storage.update_state()
message_updates_counter = 0
other_updates_counter = 0
if not states:
log.info("No states found, skipping recovery.")
return (message_updates_counter, other_updates_counter)
for state in states:
id, local_pts, _, local_date, _ = state
prev_pts = 0
while True:
try:
diff = await self.invoke(
raw.functions.updates.GetChannelDifference(
channel=await self.resolve_peer(id),
filter=raw.types.ChannelMessagesFilterEmpty(),
pts=local_pts,
limit=10000,
force=False
) if id < 0 else
raw.functions.updates.GetDifference(
pts=local_pts,
date=local_date,
qts=0
)
)
except (ChannelPrivate, ChannelInvalid):
break
if isinstance(diff, raw.types.updates.DifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.DifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.Difference):
local_pts = diff.state.pts
elif isinstance(diff, raw.types.updates.DifferenceSlice):
local_pts = diff.intermediate_state.pts
local_date = diff.intermediate_state.date
if prev_pts == local_pts:
break
prev_pts = local_pts
elif isinstance(diff, raw.types.updates.ChannelDifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.ChannelDifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.ChannelDifference):
local_pts = diff.pts
users = {i.id: i for i in diff.users}
chats = {i.id: i for i in diff.chats}
for message in diff.new_messages:
message_updates_counter += 1
self.dispatcher.updates_queue.put_nowait(
(
raw.types.UpdateNewMessage(
message=message,
pts=local_pts,
pts_count=-1
),
users,
chats
)
)
for update in diff.other_updates:
other_updates_counter += 1
self.dispatcher.updates_queue.put_nowait(
(update, users, chats)
)
if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
break
await self.storage.update_state(id)
log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
return (message_updates_counter, other_updates_counter)
async def load_session(self): async def load_session(self):
await self.storage.open() await self.storage.open()

View File

@ -180,93 +180,7 @@ class Dispatcher:
log.info("Started %s HandlerTasks", self.client.workers) log.info("Started %s HandlerTasks", self.client.workers)
if not self.client.skip_updates: if not self.client.skip_updates:
states = await self.client.storage.update_state() await self.client.recover_gaps()
if not states:
log.info("No states found, skipping recovery.")
return
message_updates_counter = 0
other_updates_counter = 0
for state in states:
id, local_pts, _, local_date, _ = state
prev_pts = 0
while True:
try:
diff = await self.client.invoke(
raw.functions.updates.GetChannelDifference(
channel=await self.client.resolve_peer(id),
filter=raw.types.ChannelMessagesFilterEmpty(),
pts=local_pts,
limit=10000
) if id < 0 else
raw.functions.updates.GetDifference(
pts=local_pts,
date=local_date,
qts=0
)
)
except (errors.ChannelPrivate, errors.ChannelInvalid):
break
if isinstance(diff, raw.types.updates.DifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.DifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.Difference):
local_pts = diff.state.pts
elif isinstance(diff, raw.types.updates.DifferenceSlice):
local_pts = diff.intermediate_state.pts
local_date = diff.intermediate_state.date
if prev_pts == local_pts:
break
prev_pts = local_pts
elif isinstance(diff, raw.types.updates.ChannelDifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.ChannelDifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.ChannelDifference):
local_pts = diff.pts
users = {i.id: i for i in diff.users}
chats = {i.id: i for i in diff.chats}
for message in diff.new_messages:
message_updates_counter += 1
self.updates_queue.put_nowait(
(
raw.types.UpdateNewMessage(
message=message,
pts=local_pts,
pts_count=-1
) if id == self.client.me.id else
raw.types.UpdateNewChannelMessage(
message=message,
pts=local_pts,
pts_count=-1
),
users,
chats
)
)
for update in diff.other_updates:
other_updates_counter += 1
self.updates_queue.put_nowait(
(update, users, chats)
)
if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
break
await self.client.storage.update_state(id)
log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
async def stop(self): async def stop(self):
if not self.client.no_updates: if not self.client.no_updates: