mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-16 20:59:29 +00:00
Unify peers cache
This commit is contained in:
parent
5dc33c6337
commit
260043d8ec
@ -324,8 +324,7 @@ class Client(Methods, BaseClient):
|
||||
now = time.time()
|
||||
|
||||
if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP:
|
||||
self.session_storage.peers_by_username.clear()
|
||||
self.session_storage.peers_by_phone.clear()
|
||||
self.session_storage.clear_cache()
|
||||
|
||||
self.get_initial_dialogs()
|
||||
self.get_contacts()
|
||||
@ -763,60 +762,7 @@ class Client(Methods, BaseClient):
|
||||
types.Chat, types.ChatForbidden,
|
||||
types.Channel, types.ChannelForbidden]]):
|
||||
for entity in entities:
|
||||
if isinstance(entity, types.User):
|
||||
user_id = entity.id
|
||||
|
||||
access_hash = entity.access_hash
|
||||
|
||||
if access_hash is None:
|
||||
continue
|
||||
|
||||
username = entity.username
|
||||
phone = entity.phone
|
||||
|
||||
input_peer = types.InputPeerUser(
|
||||
user_id=user_id,
|
||||
access_hash=access_hash
|
||||
)
|
||||
|
||||
self.session_storage.peers_by_id[user_id] = input_peer
|
||||
|
||||
if username is not None:
|
||||
self.session_storage.peers_by_username[username.lower()] = input_peer
|
||||
|
||||
if phone is not None:
|
||||
self.session_storage.peers_by_phone[phone] = input_peer
|
||||
|
||||
if isinstance(entity, (types.Chat, types.ChatForbidden)):
|
||||
chat_id = entity.id
|
||||
peer_id = -chat_id
|
||||
|
||||
input_peer = types.InputPeerChat(
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
self.session_storage.peers_by_id[peer_id] = input_peer
|
||||
|
||||
if isinstance(entity, (types.Channel, types.ChannelForbidden)):
|
||||
channel_id = entity.id
|
||||
peer_id = int("-100" + str(channel_id))
|
||||
|
||||
access_hash = entity.access_hash
|
||||
|
||||
if access_hash is None:
|
||||
continue
|
||||
|
||||
username = getattr(entity, "username", None)
|
||||
|
||||
input_peer = types.InputPeerChannel(
|
||||
channel_id=channel_id,
|
||||
access_hash=access_hash
|
||||
)
|
||||
|
||||
self.session_storage.peers_by_id[peer_id] = input_peer
|
||||
|
||||
if username is not None:
|
||||
self.session_storage.peers_by_username[username.lower()] = input_peer
|
||||
self.session_storage.cache_peer(entity)
|
||||
|
||||
def download_worker(self):
|
||||
name = threading.current_thread().name
|
||||
@ -1261,7 +1207,7 @@ class Client(Methods, BaseClient):
|
||||
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
|
||||
time.sleep(e.x)
|
||||
else:
|
||||
log.info("Total peers: {}".format(len(self.session_storage.peers_by_id)))
|
||||
log.info("Total peers: {}".format(self.session_storage.peers_count()))
|
||||
return r
|
||||
|
||||
def get_initial_dialogs(self):
|
||||
@ -1297,7 +1243,7 @@ class Client(Methods, BaseClient):
|
||||
``KeyError`` in case the peer doesn't exist in the internal database.
|
||||
"""
|
||||
try:
|
||||
return self.session_storage.peers_by_id[peer_id]
|
||||
return self.session_storage.get_peer_by_id(peer_id)
|
||||
except KeyError:
|
||||
if type(peer_id) is str:
|
||||
if peer_id in ("self", "me"):
|
||||
@ -1308,17 +1254,19 @@ class Client(Methods, BaseClient):
|
||||
try:
|
||||
int(peer_id)
|
||||
except ValueError:
|
||||
if peer_id not in self.session_storage.peers_by_username:
|
||||
try:
|
||||
self.session_storage.get_peer_by_username(peer_id)
|
||||
except KeyError:
|
||||
self.send(
|
||||
functions.contacts.ResolveUsername(
|
||||
username=peer_id
|
||||
)
|
||||
)
|
||||
|
||||
return self.session_storage.peers_by_username[peer_id]
|
||||
return self.session_storage.get_peer_by_username(peer_id)
|
||||
else:
|
||||
try:
|
||||
return self.session_storage.peers_by_phone[peer_id]
|
||||
return self.session_storage.get_peer_by_phone(peer_id)
|
||||
except KeyError:
|
||||
raise PeerIdInvalid
|
||||
|
||||
@ -1345,7 +1293,7 @@ class Client(Methods, BaseClient):
|
||||
)
|
||||
|
||||
try:
|
||||
return self.session_storage.peers_by_id[peer_id]
|
||||
return self.session_storage.get_peer_by_id(peer_id)
|
||||
except KeyError:
|
||||
raise PeerIdInvalid
|
||||
|
||||
|
@ -74,8 +74,8 @@ class BaseClient:
|
||||
self.rnd_id = MsgId
|
||||
self.channels_pts = {}
|
||||
|
||||
self.markdown = Markdown(self.session_storage.peers_by_id)
|
||||
self.html = HTML(self.session_storage.peers_by_id)
|
||||
self.markdown = Markdown(self.session_storage)
|
||||
self.html = HTML(self.session_storage)
|
||||
|
||||
self.session = None
|
||||
self.media_sessions = {}
|
||||
|
@ -44,5 +44,5 @@ class GetContacts(BaseClient):
|
||||
log.warning("get_contacts flood: waiting {} seconds".format(e.x))
|
||||
time.sleep(e.x)
|
||||
else:
|
||||
log.info("Total contacts: {}".format(len(self.session_storage.peers_by_phone)))
|
||||
log.info("Total contacts: {}".format(self.session_storage.contacts_count()))
|
||||
return [pyrogram.User._parse(self, user) for user in contacts.users]
|
||||
|
@ -17,9 +17,10 @@
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import abc
|
||||
from typing import Type
|
||||
from typing import Type, Union
|
||||
|
||||
import pyrogram
|
||||
from pyrogram.api import types
|
||||
|
||||
|
||||
class SessionDoesNotExist(Exception):
|
||||
@ -102,17 +103,41 @@ class SessionStorage(abc.ABC):
|
||||
def is_bot(self, val):
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def peers_by_id(self):
|
||||
def clear_cache(self):
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def peers_by_username(self):
|
||||
def cache_peer(self, entity: Union[types.User,
|
||||
types.Chat, types.ChatForbidden,
|
||||
types.Channel, types.ChannelForbidden]):
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def peers_by_phone(self):
|
||||
def get_peer_by_id(self, val: int):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_peer_by_username(self, val: str):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_peer_by_phone(self, val: str):
|
||||
...
|
||||
|
||||
def get_peer(self, peer_id: Union[int, str]):
|
||||
if isinstance(peer_id, int):
|
||||
return self.get_peer_by_id(peer_id)
|
||||
else:
|
||||
peer_id = peer_id.lstrip('+@')
|
||||
if peer_id.isdigit():
|
||||
return self.get_peer_by_phone(peer_id)
|
||||
return self.get_peer_by_username(peer_id)
|
||||
|
||||
@abc.abstractmethod
|
||||
def peers_count(self):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def contacts_count(self):
|
||||
...
|
||||
|
@ -58,19 +58,19 @@ class JsonSessionStorage(MemorySessionStorage):
|
||||
self._is_bot = s.get('is_bot', self._is_bot)
|
||||
|
||||
for k, v in s.get("peers_by_id", {}).items():
|
||||
self._peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
|
||||
self._peers_cache['i' + k] = utils.get_input_peer(int(k), v)
|
||||
|
||||
for k, v in s.get("peers_by_username", {}).items():
|
||||
peer = self._peers_by_id.get(v, None)
|
||||
|
||||
if peer:
|
||||
self._peers_by_username[k] = peer
|
||||
try:
|
||||
self._peers_cache['u' + k] = self.get_peer_by_id(v)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
for k, v in s.get("peers_by_phone", {}).items():
|
||||
peer = self._peers_by_id.get(v, None)
|
||||
|
||||
if peer:
|
||||
self._peers_by_phone[k] = peer
|
||||
try:
|
||||
self._peers_cache['p' + k] = self.get_peer_by_id(v)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def save(self, sync=False):
|
||||
file_path = self._get_file_name(self._session_name)
|
||||
@ -93,16 +93,19 @@ class JsonSessionStorage(MemorySessionStorage):
|
||||
'date': self._date,
|
||||
'is_bot': self._is_bot,
|
||||
'peers_by_id': {
|
||||
k: getattr(v, "access_hash", None)
|
||||
for k, v in self._peers_by_id.copy().items()
|
||||
k[1:]: getattr(v, "access_hash", None)
|
||||
for k, v in self._peers_cache.copy().items()
|
||||
if k[0] == 'i'
|
||||
},
|
||||
'peers_by_username': {
|
||||
k: utils.get_peer_id(v)
|
||||
for k, v in self._peers_by_username.copy().items()
|
||||
k[1:]: utils.get_peer_id(v)
|
||||
for k, v in self._peers_cache.copy().items()
|
||||
if k[0] == 'u'
|
||||
},
|
||||
'peers_by_phone': {
|
||||
k: utils.get_peer_id(v)
|
||||
for k, v in self._peers_by_phone.copy().items()
|
||||
k[1:]: utils.get_peer_id(v)
|
||||
for k, v in self._peers_cache.copy().items()
|
||||
if k[0] == 'p'
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import pyrogram
|
||||
from pyrogram.api import types
|
||||
from . import SessionStorage, SessionDoesNotExist
|
||||
|
||||
|
||||
@ -11,9 +12,7 @@ class MemorySessionStorage(SessionStorage):
|
||||
self._user_id = None
|
||||
self._date = 0
|
||||
self._is_bot = False
|
||||
self._peers_by_id = {}
|
||||
self._peers_by_username = {}
|
||||
self._peers_by_phone = {}
|
||||
self._peers_cache = {}
|
||||
|
||||
def load(self):
|
||||
raise SessionDoesNotExist()
|
||||
@ -72,14 +71,48 @@ class MemorySessionStorage(SessionStorage):
|
||||
def is_bot(self, val):
|
||||
self._is_bot = val
|
||||
|
||||
@property
|
||||
def peers_by_id(self):
|
||||
return self._peers_by_id
|
||||
def clear_cache(self):
|
||||
keys = list(filter(lambda k: k[0] in 'up', self._peers_cache.keys()))
|
||||
for key in keys:
|
||||
try:
|
||||
del self._peers_cache[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def peers_by_username(self):
|
||||
return self._peers_by_username
|
||||
def cache_peer(self, entity):
|
||||
if isinstance(entity, types.User):
|
||||
input_peer = types.InputPeerUser(
|
||||
user_id=entity.id,
|
||||
access_hash=entity.access_hash
|
||||
)
|
||||
self._peers_cache['i' + str(entity.id)] = input_peer
|
||||
if entity.username:
|
||||
self._peers_cache['u' + entity.username.lower()] = input_peer
|
||||
if entity.phone:
|
||||
self._peers_cache['p' + entity.phone] = input_peer
|
||||
elif isinstance(entity, (types.Chat, types.ChatForbidden)):
|
||||
self._peers_cache['i-' + str(entity.id)] = types.InputPeerChat(chat_id=entity.id)
|
||||
elif isinstance(entity, (types.Channel, types.ChannelForbidden)):
|
||||
input_peer = types.InputPeerChannel(
|
||||
channel_id=entity.id,
|
||||
access_hash=entity.access_hash
|
||||
)
|
||||
self._peers_cache['i-100' + str(entity.id)] = input_peer
|
||||
username = getattr(entity, "username", None)
|
||||
if username:
|
||||
self._peers_cache['u' + username.lower()] = input_peer
|
||||
|
||||
@property
|
||||
def peers_by_phone(self):
|
||||
return self._peers_by_phone
|
||||
def get_peer_by_id(self, val):
|
||||
return self._peers_cache['i' + str(val)]
|
||||
|
||||
def get_peer_by_username(self, val):
|
||||
return self._peers_cache['u' + val.lower()]
|
||||
|
||||
def get_peer_by_phone(self, val):
|
||||
return self._peers_cache['p' + val]
|
||||
|
||||
def peers_count(self):
|
||||
return len(list(filter(lambda k: k[0] == 'i', self._peers_cache.keys())))
|
||||
|
||||
def contacts_count(self):
|
||||
return len(list(filter(lambda k: k[0] == 'p', self._peers_cache.keys())))
|
||||
|
@ -29,14 +29,15 @@ from pyrogram.api.types import (
|
||||
InputMessageEntityMentionName as Mention,
|
||||
)
|
||||
from . import utils
|
||||
from ..session_storage import SessionStorage
|
||||
|
||||
|
||||
class HTML:
|
||||
HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)</\1>")
|
||||
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
|
||||
|
||||
def __init__(self, peers_by_id):
|
||||
self.peers_by_id = peers_by_id
|
||||
def __init__(self, session_storage: SessionStorage):
|
||||
self.session_storage = session_storage
|
||||
|
||||
def parse(self, message: str):
|
||||
entities = []
|
||||
@ -52,7 +53,10 @@ class HTML:
|
||||
|
||||
if mention:
|
||||
user_id = int(mention.group(1))
|
||||
input_user = self.peers_by_id.get(user_id, None)
|
||||
try:
|
||||
input_user = self.session_storage.get_peer_by_id(user_id)
|
||||
except KeyError:
|
||||
input_user = None
|
||||
|
||||
entity = (
|
||||
Mention(start, len(body), input_user)
|
||||
|
@ -29,6 +29,7 @@ from pyrogram.api.types import (
|
||||
InputMessageEntityMentionName as Mention
|
||||
)
|
||||
from . import utils
|
||||
from ..session_storage import SessionStorage
|
||||
|
||||
|
||||
class Markdown:
|
||||
@ -52,8 +53,8 @@ class Markdown:
|
||||
))
|
||||
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
|
||||
|
||||
def __init__(self, peers_by_id: dict):
|
||||
self.peers_by_id = peers_by_id
|
||||
def __init__(self, session_storage: SessionStorage):
|
||||
self.session_storage = session_storage
|
||||
|
||||
def parse(self, message: str):
|
||||
message = utils.add_surrogates(str(message)).strip()
|
||||
@ -69,7 +70,10 @@ class Markdown:
|
||||
|
||||
if mention:
|
||||
user_id = int(mention.group(1))
|
||||
input_user = self.peers_by_id.get(user_id, None)
|
||||
try:
|
||||
input_user = self.session_storage.get_peer_by_id(user_id)
|
||||
except KeyError:
|
||||
input_user = None
|
||||
|
||||
entity = (
|
||||
Mention(start, len(text), input_user)
|
||||
|
Loading…
Reference in New Issue
Block a user