From 57433be99b42f326751d1acde2b177c2a8c4bd94 Mon Sep 17 00:00:00 2001 From: KurimuzonAkuma Date: Tue, 29 Oct 2024 13:16:59 +0300 Subject: [PATCH] Add update_usernames to storage --- pyrogram/client.py | 27 ++++++++------- pyrogram/storage/sqlite_storage.py | 54 ++++++++++++------------------ pyrogram/storage/storage.py | 26 +++++++++++--- 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/pyrogram/client.py b/pyrogram/client.py index 4c948e5c..beab6ae2 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -548,25 +548,26 @@ class Client(Methods): async def fetch_peers(self, peers: List[Union[raw.types.User, raw.types.Chat, raw.types.Channel]]) -> bool: is_min = False parsed_peers = [] + parsed_usernames = [] for peer in peers: if getattr(peer, "min", False): is_min = True continue - usernames = None + usernames = [] phone_number = None if isinstance(peer, raw.types.User): peer_id = peer.id access_hash = peer.access_hash - usernames = ( - [peer.username.lower()] if peer.username - else [username.username.lower() for username in peer.usernames] if peer.usernames - else None - ) phone_number = peer.phone peer_type = "bot" if peer.bot else "user" + + if peer.username: + usernames.append(peer.username.lower()) + elif peer.usernames: + usernames.extend(username.username.lower() for username in peer.usernames) elif isinstance(peer, (raw.types.Chat, raw.types.ChatForbidden)): peer_id = -peer.id access_hash = 0 @@ -574,12 +575,12 @@ class Client(Methods): elif isinstance(peer, raw.types.Channel): peer_id = utils.get_channel_id(peer.id) access_hash = peer.access_hash - usernames = ( - [peer.username.lower()] if peer.username - else [username.username.lower() for username in peer.usernames] if peer.usernames - else None - ) peer_type = "channel" if peer.broadcast else "supergroup" + + if peer.username: + usernames.append(peer.username.lower()) + elif peer.usernames: + usernames.extend(username.username.lower() for username in peer.usernames) elif isinstance(peer, raw.types.ChannelForbidden): peer_id = utils.get_channel_id(peer.id) access_hash = peer.access_hash @@ -587,9 +588,11 @@ class Client(Methods): else: continue - parsed_peers.append((peer_id, access_hash, peer_type, usernames, phone_number)) + parsed_peers.append((peer_id, access_hash, peer_type, phone_number)) + parsed_usernames.append((peer_id, usernames)) await self.storage.update_peers(parsed_peers) + await self.storage.update_usernames(parsed_usernames) return is_min diff --git a/pyrogram/storage/sqlite_storage.py b/pyrogram/storage/sqlite_storage.py index fa0e5f84..df25debd 100644 --- a/pyrogram/storage/sqlite_storage.py +++ b/pyrogram/storage/sqlite_storage.py @@ -141,34 +141,23 @@ class SQLiteStorage(Storage): async def delete(self): raise NotImplementedError - async def update_peers(self, peers: List[Tuple[int, int, str, List[str], str]]): - peers_data = [] - usernames_data = [] - ids_to_delete = [] - - for id, access_hash, type, usernames, phone_number in peers: - ids_to_delete.append((id,)) - peers_data.append((id, access_hash, type, phone_number)) - - if usernames: - usernames_data.extend([(id, username) for username in usernames]) - + async def update_peers(self, peers: List[Tuple[int, int, str, str]]): self.conn.executemany( "REPLACE INTO peers (id, access_hash, type, phone_number) VALUES (?, ?, ?, ?)", - peers_data + peers + ) + + async def update_usernames(self, usernames: List[Tuple[int, List[str]]]): + self.conn.executemany( + "DELETE FROM usernames WHERE id = ?", + [(id,) for id, _ in usernames] ) self.conn.executemany( - "DELETE FROM usernames WHERE id = ?", - ids_to_delete + "REPLACE INTO usernames (id, username) VALUES (?, ?)", + [(id, username) for id, usernames in usernames for username in usernames] ) - if usernames_data: - self.conn.executemany( - "REPLACE INTO usernames (id, username) VALUES (?, ?)", - usernames_data - ) - async def update_state(self, value: Tuple[int, int, int, int, int] = object): if value == object: return self.conn.execute( @@ -176,18 +165,17 @@ class SQLiteStorage(Storage): "ORDER BY date ASC" ).fetchall() else: - with self.conn: - if isinstance(value, int): - self.conn.execute( - "DELETE FROM update_state WHERE id = ?", - (value,) - ) - else: - self.conn.execute( - "REPLACE INTO update_state (id, pts, qts, date, seq)" - "VALUES (?, ?, ?, ?, ?)", - value - ) + if isinstance(value, int): + self.conn.execute( + "DELETE FROM update_state WHERE id = ?", + (value,) + ) + else: + self.conn.execute( + "REPLACE INTO update_state (id, pts, qts, date, seq)" + "VALUES (?, ?, ?, ?, ?)", + value + ) async def get_peer_by_id(self, peer_id: int): r = self.conn.execute( diff --git a/pyrogram/storage/storage.py b/pyrogram/storage/storage.py index fca7a170..15a3d5d4 100644 --- a/pyrogram/storage/storage.py +++ b/pyrogram/storage/storage.py @@ -57,32 +57,48 @@ class Storage(ABC): @abstractmethod async def delete(self): - """Deletes the storage.""" + """Deletes the storage file.""" raise NotImplementedError @abstractmethod - async def update_peers(self, peers: List[Tuple[int, int, str, List[str], str]]): + async def update_peers(self, peers: List[Tuple[int, int, str, str]]): """ Update the peers table with the provided information. Parameters: - peers (``List[Tuple[int, int, str, List[str], str]]``): A list of tuples containing the + peers (``List[Tuple[int, int, str, str]]``): + A list of tuples containing the information of the peers to be updated. Each tuple must contain the following information: - ``int``: The peer id. - ``int``: The peer access hash. - ``str``: The peer type (user, chat or channel). - - List of ``str``: The peer username (if any). - ``str``: The peer phone number (if any). """ raise NotImplementedError + @abstractmethod + async def update_usernames(self, usernames: List[Tuple[int, List[str]]]): + """ + Update the usernames table with the provided information. + + Parameters: + usernames (``List[Tuple[int, List[str]]]``): + A list of tuples containing the + information of the usernames to be updated. Each tuple must contain the following + information: + - ``int``: The peer id. + - List of ``str``: The peer username (if any). + """ + raise NotImplementedError + @abstractmethod async def update_state(self, update_state: Tuple[int, int, int, int, int] = object): """Get or set the update state of the current session. Parameters: - update_state (``Tuple[int, int, int, int, int]``): A tuple containing the update state to set. + update_state (``Tuple[int, int, int, int, int]``): + A tuple containing the update state to set. Tuple must contain the following information: - ``int``: The id of the entity. - ``int``: The pts.