Unify peers cache

This commit is contained in:
bakatrouble 2019-02-26 19:24:00 +03:00
parent 5dc33c6337
commit 260043d8ec
8 changed files with 122 additions and 105 deletions

View File

@ -324,8 +324,7 @@ class Client(Methods, BaseClient):
now = time.time()
if abs(now - self.session_storage.date) > Client.OFFLINE_SLEEP:
self.session_storage.peers_by_username.clear()
self.session_storage.peers_by_phone.clear()
self.session_storage.clear_cache()
self.get_initial_dialogs()
self.get_contacts()
@ -763,60 +762,7 @@ class Client(Methods, BaseClient):
types.Chat, types.ChatForbidden,
types.Channel, types.ChannelForbidden]]):
for entity in entities:
if isinstance(entity, types.User):
user_id = entity.id
access_hash = entity.access_hash
if access_hash is None:
continue
username = entity.username
phone = entity.phone
input_peer = types.InputPeerUser(
user_id=user_id,
access_hash=access_hash
)
self.session_storage.peers_by_id[user_id] = input_peer
if username is not None:
self.session_storage.peers_by_username[username.lower()] = input_peer
if phone is not None:
self.session_storage.peers_by_phone[phone] = input_peer
if isinstance(entity, (types.Chat, types.ChatForbidden)):
chat_id = entity.id
peer_id = -chat_id
input_peer = types.InputPeerChat(
chat_id=chat_id
)
self.session_storage.peers_by_id[peer_id] = input_peer
if isinstance(entity, (types.Channel, types.ChannelForbidden)):
channel_id = entity.id
peer_id = int("-100" + str(channel_id))
access_hash = entity.access_hash
if access_hash is None:
continue
username = getattr(entity, "username", None)
input_peer = types.InputPeerChannel(
channel_id=channel_id,
access_hash=access_hash
)
self.session_storage.peers_by_id[peer_id] = input_peer
if username is not None:
self.session_storage.peers_by_username[username.lower()] = input_peer
self.session_storage.cache_peer(entity)
def download_worker(self):
name = threading.current_thread().name
@ -1261,7 +1207,7 @@ class Client(Methods, BaseClient):
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
time.sleep(e.x)
else:
log.info("Total peers: {}".format(len(self.session_storage.peers_by_id)))
log.info("Total peers: {}".format(self.session_storage.peers_count()))
return r
def get_initial_dialogs(self):
@ -1297,7 +1243,7 @@ class Client(Methods, BaseClient):
``KeyError`` in case the peer doesn't exist in the internal database.
"""
try:
return self.session_storage.peers_by_id[peer_id]
return self.session_storage.get_peer_by_id(peer_id)
except KeyError:
if type(peer_id) is str:
if peer_id in ("self", "me"):
@ -1308,17 +1254,19 @@ class Client(Methods, BaseClient):
try:
int(peer_id)
except ValueError:
if peer_id not in self.session_storage.peers_by_username:
try:
self.session_storage.get_peer_by_username(peer_id)
except KeyError:
self.send(
functions.contacts.ResolveUsername(
username=peer_id
)
)
return self.session_storage.peers_by_username[peer_id]
return self.session_storage.get_peer_by_username(peer_id)
else:
try:
return self.session_storage.peers_by_phone[peer_id]
return self.session_storage.get_peer_by_phone(peer_id)
except KeyError:
raise PeerIdInvalid
@ -1345,7 +1293,7 @@ class Client(Methods, BaseClient):
)
try:
return self.session_storage.peers_by_id[peer_id]
return self.session_storage.get_peer_by_id(peer_id)
except KeyError:
raise PeerIdInvalid

View File

@ -74,8 +74,8 @@ class BaseClient:
self.rnd_id = MsgId
self.channels_pts = {}
self.markdown = Markdown(self.session_storage.peers_by_id)
self.html = HTML(self.session_storage.peers_by_id)
self.markdown = Markdown(self.session_storage)
self.html = HTML(self.session_storage)
self.session = None
self.media_sessions = {}

View File

@ -44,5 +44,5 @@ class GetContacts(BaseClient):
log.warning("get_contacts flood: waiting {} seconds".format(e.x))
time.sleep(e.x)
else:
log.info("Total contacts: {}".format(len(self.session_storage.peers_by_phone)))
log.info("Total contacts: {}".format(self.session_storage.contacts_count()))
return [pyrogram.User._parse(self, user) for user in contacts.users]

View File

@ -17,9 +17,10 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import abc
from typing import Type
from typing import Type, Union
import pyrogram
from pyrogram.api import types
class SessionDoesNotExist(Exception):
@ -102,17 +103,41 @@ class SessionStorage(abc.ABC):
def is_bot(self, val):
...
@property
@abc.abstractmethod
def peers_by_id(self):
def clear_cache(self):
...
@property
@abc.abstractmethod
def peers_by_username(self):
def cache_peer(self, entity: Union[types.User,
types.Chat, types.ChatForbidden,
types.Channel, types.ChannelForbidden]):
...
@property
@abc.abstractmethod
def peers_by_phone(self):
def get_peer_by_id(self, val: int):
...
@abc.abstractmethod
def get_peer_by_username(self, val: str):
...
@abc.abstractmethod
def get_peer_by_phone(self, val: str):
...
def get_peer(self, peer_id: Union[int, str]):
if isinstance(peer_id, int):
return self.get_peer_by_id(peer_id)
else:
peer_id = peer_id.lstrip('+@')
if peer_id.isdigit():
return self.get_peer_by_phone(peer_id)
return self.get_peer_by_username(peer_id)
@abc.abstractmethod
def peers_count(self):
...
@abc.abstractmethod
def contacts_count(self):
...

View File

@ -58,19 +58,19 @@ class JsonSessionStorage(MemorySessionStorage):
self._is_bot = s.get('is_bot', self._is_bot)
for k, v in s.get("peers_by_id", {}).items():
self._peers_by_id[int(k)] = utils.get_input_peer(int(k), v)
self._peers_cache['i' + k] = utils.get_input_peer(int(k), v)
for k, v in s.get("peers_by_username", {}).items():
peer = self._peers_by_id.get(v, None)
if peer:
self._peers_by_username[k] = peer
try:
self._peers_cache['u' + k] = self.get_peer_by_id(v)
except KeyError:
pass
for k, v in s.get("peers_by_phone", {}).items():
peer = self._peers_by_id.get(v, None)
if peer:
self._peers_by_phone[k] = peer
try:
self._peers_cache['p' + k] = self.get_peer_by_id(v)
except KeyError:
pass
def save(self, sync=False):
file_path = self._get_file_name(self._session_name)
@ -93,16 +93,19 @@ class JsonSessionStorage(MemorySessionStorage):
'date': self._date,
'is_bot': self._is_bot,
'peers_by_id': {
k: getattr(v, "access_hash", None)
for k, v in self._peers_by_id.copy().items()
k[1:]: getattr(v, "access_hash", None)
for k, v in self._peers_cache.copy().items()
if k[0] == 'i'
},
'peers_by_username': {
k: utils.get_peer_id(v)
for k, v in self._peers_by_username.copy().items()
k[1:]: utils.get_peer_id(v)
for k, v in self._peers_cache.copy().items()
if k[0] == 'u'
},
'peers_by_phone': {
k: utils.get_peer_id(v)
for k, v in self._peers_by_phone.copy().items()
k[1:]: utils.get_peer_id(v)
for k, v in self._peers_cache.copy().items()
if k[0] == 'p'
}
}

View File

@ -1,4 +1,5 @@
import pyrogram
from pyrogram.api import types
from . import SessionStorage, SessionDoesNotExist
@ -11,9 +12,7 @@ class MemorySessionStorage(SessionStorage):
self._user_id = None
self._date = 0
self._is_bot = False
self._peers_by_id = {}
self._peers_by_username = {}
self._peers_by_phone = {}
self._peers_cache = {}
def load(self):
raise SessionDoesNotExist()
@ -72,14 +71,48 @@ class MemorySessionStorage(SessionStorage):
def is_bot(self, val):
self._is_bot = val
@property
def peers_by_id(self):
return self._peers_by_id
def clear_cache(self):
keys = list(filter(lambda k: k[0] in 'up', self._peers_cache.keys()))
for key in keys:
try:
del self._peers_cache[key]
except KeyError:
pass
@property
def peers_by_username(self):
return self._peers_by_username
def cache_peer(self, entity):
if isinstance(entity, types.User):
input_peer = types.InputPeerUser(
user_id=entity.id,
access_hash=entity.access_hash
)
self._peers_cache['i' + str(entity.id)] = input_peer
if entity.username:
self._peers_cache['u' + entity.username.lower()] = input_peer
if entity.phone:
self._peers_cache['p' + entity.phone] = input_peer
elif isinstance(entity, (types.Chat, types.ChatForbidden)):
self._peers_cache['i-' + str(entity.id)] = types.InputPeerChat(chat_id=entity.id)
elif isinstance(entity, (types.Channel, types.ChannelForbidden)):
input_peer = types.InputPeerChannel(
channel_id=entity.id,
access_hash=entity.access_hash
)
self._peers_cache['i-100' + str(entity.id)] = input_peer
username = getattr(entity, "username", None)
if username:
self._peers_cache['u' + username.lower()] = input_peer
@property
def peers_by_phone(self):
return self._peers_by_phone
def get_peer_by_id(self, val):
return self._peers_cache['i' + str(val)]
def get_peer_by_username(self, val):
return self._peers_cache['u' + val.lower()]
def get_peer_by_phone(self, val):
return self._peers_cache['p' + val]
def peers_count(self):
return len(list(filter(lambda k: k[0] == 'i', self._peers_cache.keys())))
def contacts_count(self):
return len(list(filter(lambda k: k[0] == 'p', self._peers_cache.keys())))

View File

@ -29,14 +29,15 @@ from pyrogram.api.types import (
InputMessageEntityMentionName as Mention,
)
from . import utils
from ..session_storage import SessionStorage
class HTML:
HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)</\1>")
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
def __init__(self, peers_by_id):
self.peers_by_id = peers_by_id
def __init__(self, session_storage: SessionStorage):
self.session_storage = session_storage
def parse(self, message: str):
entities = []
@ -52,7 +53,10 @@ class HTML:
if mention:
user_id = int(mention.group(1))
input_user = self.peers_by_id.get(user_id, None)
try:
input_user = self.session_storage.get_peer_by_id(user_id)
except KeyError:
input_user = None
entity = (
Mention(start, len(body), input_user)

View File

@ -29,6 +29,7 @@ from pyrogram.api.types import (
InputMessageEntityMentionName as Mention
)
from . import utils
from ..session_storage import SessionStorage
class Markdown:
@ -52,8 +53,8 @@ class Markdown:
))
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
def __init__(self, peers_by_id: dict):
self.peers_by_id = peers_by_id
def __init__(self, session_storage: SessionStorage):
self.session_storage = session_storage
def parse(self, message: str):
message = utils.add_surrogates(str(message)).strip()
@ -69,7 +70,10 @@ class Markdown:
if mention:
user_id = int(mention.group(1))
input_user = self.peers_by_id.get(user_id, None)
try:
input_user = self.session_storage.get_peer_by_id(user_id)
except KeyError:
input_user = None
entity = (
Mention(start, len(text), input_user)