Update Auth and Session to accommodate Storage Engines

This commit is contained in:
Dan 2019-06-19 16:01:23 +02:00
parent d472d06c48
commit 682591ea8f
13 changed files with 36 additions and 588 deletions

View File

@ -26,14 +26,13 @@ import shutil
import tempfile
import threading
import time
import warnings
from configparser import ConfigParser
from hashlib import sha256, md5
from importlib import import_module
from pathlib import Path
from signal import signal, SIGINT, SIGTERM, SIGABRT
from threading import Thread
from typing import Union, List, Type
from typing import Union, List
from pyrogram.api import functions, types
from pyrogram.api.core import TLObject
@ -205,24 +204,9 @@ class Client(Methods, BaseClient):
no_updates: bool = None,
takeout: bool = None
):
super().__init__()
if isinstance(session_name, str):
if session_name == ':memory:':
session_storage = MemorySessionStorage(self)
elif session_name.startswith(':'):
session_storage = StringSessionStorage(self, session_name)
else:
session_storage = SQLiteSessionStorage(self, session_name)
elif isinstance(session_name, SessionStorage):
session_storage = session_name
else:
raise RuntimeError('Wrong session_name passed, expected str or SessionConfig subclass')
super().__init__(session_storage)
super().__init__(session_storage)
self.session_name = str(session_name) # TODO: build correct session name
self.session_name = session_name
self.api_id = int(api_id) if api_id else None
self.api_hash = api_hash
self.app_version = app_version
@ -232,7 +216,7 @@ class Client(Methods, BaseClient):
self.ipv6 = ipv6
# TODO: Make code consistent, use underscore for private/protected fields
self._proxy = proxy
self.session_storage.test_mode = test_mode
self.test_mode = test_mode
self.bot_token = bot_token
self.phone_number = phone_number
self.phone_code = phone_code

View File

@ -27,7 +27,6 @@ from threading import Lock
from pyrogram import __version__
from ..style import Markdown, HTML
from ...session.internals import MsgId
from ..session_storage import SessionStorage
class BaseClient:

View File

@ -1,139 +0,0 @@
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import abc
from typing import Type, Union
import pyrogram
from pyrogram.api import types
class SessionDoesNotExist(Exception):
pass
class SessionStorage(abc.ABC):
def __init__(self, client: 'pyrogram.client.BaseClient'):
self._client = client
@abc.abstractmethod
def load(self):
...
@abc.abstractmethod
def save(self, sync=False):
...
@property
@abc.abstractmethod
def dc_id(self):
...
@dc_id.setter
@abc.abstractmethod
def dc_id(self, val):
...
@property
@abc.abstractmethod
def test_mode(self):
...
@test_mode.setter
@abc.abstractmethod
def test_mode(self, val):
...
@property
@abc.abstractmethod
def auth_key(self):
...
@auth_key.setter
@abc.abstractmethod
def auth_key(self, val):
...
@property
@abc.abstractmethod
def user_id(self):
...
@user_id.setter
@abc.abstractmethod
def user_id(self, val):
...
@property
@abc.abstractmethod
def date(self):
...
@date.setter
@abc.abstractmethod
def date(self, val):
...
@property
@abc.abstractmethod
def is_bot(self):
...
@is_bot.setter
@abc.abstractmethod
def is_bot(self, val):
...
@abc.abstractmethod
def clear_cache(self):
...
@abc.abstractmethod
def cache_peer(self, entity: Union[types.User,
types.Chat, types.ChatForbidden,
types.Channel, types.ChannelForbidden]):
...
@abc.abstractmethod
def get_peer_by_id(self, val: int):
...
@abc.abstractmethod
def get_peer_by_username(self, val: str):
...
@abc.abstractmethod
def get_peer_by_phone(self, val: str):
...
def get_peer(self, peer_id: Union[int, str]):
if isinstance(peer_id, int):
return self.get_peer_by_id(peer_id)
else:
peer_id = peer_id.lstrip('+@')
if peer_id.isdigit():
return self.get_peer_by_phone(peer_id)
return self.get_peer_by_username(peer_id)
@abc.abstractmethod
def peers_count(self):
...
@abc.abstractmethod
def contacts_count(self):
...

View File

@ -1,63 +0,0 @@
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import base64
import json
import logging
import os
import shutil
import pyrogram
from ..ext import utils
from . import MemorySessionStorage, SessionDoesNotExist
log = logging.getLogger(__name__)
EXTENSION = '.session'
class JsonSessionStorage(MemorySessionStorage):
def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_name: str):
super(JsonSessionStorage, self).__init__(client)
self._session_name = session_name
def _get_file_name(self, name: str):
if not name.endswith(EXTENSION):
name += EXTENSION
return os.path.join(self._client.workdir, name)
def load(self):
file_path = self._get_file_name(self._session_name)
log.info('Loading JSON session from {}'.format(file_path))
try:
with open(file_path, encoding='utf-8') as f:
s = json.load(f)
except FileNotFoundError:
raise SessionDoesNotExist()
self._dc_id = s["dc_id"]
self._test_mode = s["test_mode"]
self._auth_key = base64.b64decode("".join(s["auth_key"])) # join split key
self._user_id = s["user_id"]
self._date = s.get("date", 0)
self._is_bot = s.get('is_bot', self._is_bot)
def save(self, sync=False):
pass

View File

@ -1,115 +0,0 @@
import pyrogram
from pyrogram.api import types
from . import SessionStorage, SessionDoesNotExist
class MemorySessionStorage(SessionStorage):
def __init__(self, client: 'pyrogram.client.ext.BaseClient'):
super(MemorySessionStorage, self).__init__(client)
self._dc_id = 1
self._test_mode = None
self._auth_key = None
self._user_id = None
self._date = 0
self._is_bot = False
self._peers_cache = {}
def load(self):
raise SessionDoesNotExist()
def save(self, sync=False):
pass
@property
def dc_id(self):
return self._dc_id
@dc_id.setter
def dc_id(self, val):
self._dc_id = val
@property
def test_mode(self):
return self._test_mode
@test_mode.setter
def test_mode(self, val):
self._test_mode = val
@property
def auth_key(self):
return self._auth_key
@auth_key.setter
def auth_key(self, val):
self._auth_key = val
@property
def user_id(self):
return self._user_id
@user_id.setter
def user_id(self, val):
self._user_id = val
@property
def date(self):
return self._date
@date.setter
def date(self, val):
self._date = val
@property
def is_bot(self):
return self._is_bot
@is_bot.setter
def is_bot(self, val):
self._is_bot = val
def clear_cache(self):
keys = list(filter(lambda k: k[0] in 'up', self._peers_cache.keys()))
for key in keys:
try:
del self._peers_cache[key]
except KeyError:
pass
def cache_peer(self, entity):
if isinstance(entity, types.User):
input_peer = types.InputPeerUser(
user_id=entity.id,
access_hash=entity.access_hash
)
self._peers_cache['i' + str(entity.id)] = input_peer
if entity.username:
self._peers_cache['u' + entity.username.lower()] = input_peer
if entity.phone:
self._peers_cache['p' + entity.phone] = input_peer
elif isinstance(entity, (types.Chat, types.ChatForbidden)):
self._peers_cache['i-' + str(entity.id)] = types.InputPeerChat(chat_id=entity.id)
elif isinstance(entity, (types.Channel, types.ChannelForbidden)):
input_peer = types.InputPeerChannel(
channel_id=entity.id,
access_hash=entity.access_hash
)
self._peers_cache['i-100' + str(entity.id)] = input_peer
username = getattr(entity, "username", None)
if username:
self._peers_cache['u' + username.lower()] = input_peer
def get_peer_by_id(self, val):
return self._peers_cache['i' + str(val)]
def get_peer_by_username(self, val):
return self._peers_cache['u' + val.lower()]
def get_peer_by_phone(self, val):
return self._peers_cache['p' + val]
def peers_count(self):
return len(list(filter(lambda k: k[0] == 'i', self._peers_cache.keys())))
def contacts_count(self):
return len(list(filter(lambda k: k[0] == 'p', self._peers_cache.keys())))

View File

@ -1,24 +0,0 @@
create table sessions (
dc_id integer primary key,
test_mode integer,
auth_key blob,
user_id integer,
date integer,
is_bot integer
);
create table peers_cache (
id integer primary key,
hash integer,
username text,
phone integer
);
create table migrations (
name text primary key
);
create index username_idx on peers_cache(username);
create index phone_idx on peers_cache(phone);
insert into migrations (name) values ('0001');

View File

@ -1,153 +0,0 @@
# Pyrogram - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-2019 Dan Tès <https://github.com/delivrance>
#
# This file is part of Pyrogram.
#
# Pyrogram is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrogram is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import logging
import os
import shutil
import sqlite3
from threading import Lock
import pyrogram
from ....api import types
from ...ext import utils
from .. import MemorySessionStorage, SessionDoesNotExist, JsonSessionStorage
log = logging.getLogger(__name__)
EXTENSION = '.session'
MIGRATIONS = ['0001']
class SQLiteSessionStorage(MemorySessionStorage):
def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_name: str):
super(SQLiteSessionStorage, self).__init__(client)
self._session_name = session_name
self._conn = None # type: sqlite3.Connection
self._lock = Lock()
def _get_file_name(self, name: str):
if not name.endswith(EXTENSION):
name += EXTENSION
return os.path.join(self._client.workdir, name)
def _apply_migrations(self, new_db=False):
self._conn.execute('PRAGMA read_uncommitted = true')
migrations = MIGRATIONS.copy()
if not new_db:
cursor = self._conn.cursor()
cursor.execute('select name from migrations')
for row in cursor.fetchone():
migrations.remove(row)
for name in migrations:
with open(os.path.join(os.path.dirname(__file__), '{}.sql'.format(name))) as script:
self._conn.executescript(script.read())
def _migrate_from_json(self):
jss = JsonSessionStorage(self._client, self._session_name)
jss.load()
file_path = self._get_file_name(self._session_name)
self._conn = sqlite3.connect(file_path + '.tmp')
self._apply_migrations(new_db=True)
self._dc_id, self._test_mode, self._auth_key, self._user_id, self._date, self._is_bot = \
jss.dc_id, jss.test_mode, jss.auth_key, jss.user_id, jss.date, jss.is_bot
self.save()
self._conn.close()
shutil.move(file_path + '.tmp', file_path)
log.warning('Session was migrated from JSON, loading...')
self.load()
def load(self):
file_path = self._get_file_name(self._session_name)
log.info('Loading SQLite session from {}'.format(file_path))
if os.path.isfile(file_path):
try:
self._conn = sqlite3.connect(file_path, isolation_level='EXCLUSIVE', check_same_thread=False)
self._apply_migrations()
except sqlite3.DatabaseError:
log.warning('Trying to migrate session from JSON...')
self._migrate_from_json()
return
else:
self._conn = sqlite3.connect(file_path, isolation_level='EXCLUSIVE', check_same_thread=False)
self._apply_migrations(new_db=True)
cursor = self._conn.cursor()
cursor.execute('select dc_id, test_mode, auth_key, user_id, "date", is_bot from sessions')
row = cursor.fetchone()
if not row:
raise SessionDoesNotExist()
self._dc_id = row[0]
self._test_mode = bool(row[1])
self._auth_key = row[2]
self._user_id = row[3]
self._date = row[4]
self._is_bot = bool(row[5])
def cache_peer(self, entity):
peer_id = username = phone = access_hash = None
if isinstance(entity, types.User):
peer_id = entity.id
username = entity.username.lower() if entity.username else None
phone = entity.phone or None
access_hash = entity.access_hash
elif isinstance(entity, (types.Chat, types.ChatForbidden)):
peer_id = -entity.id
elif isinstance(entity, (types.Channel, types.ChannelForbidden)):
peer_id = int('-100' + str(entity.id))
username = entity.username.lower() if hasattr(entity, 'username') and entity.username else None
access_hash = entity.access_hash
with self._lock:
self._conn.execute('insert or replace into peers_cache values (?, ?, ?, ?)',
(peer_id, access_hash, username, phone))
def get_peer_by_id(self, val):
cursor = self._conn.cursor()
cursor.execute('select id, hash from peers_cache where id = ?', (val,))
row = cursor.fetchone()
if not row:
raise KeyError(val)
return utils.get_input_peer(row[0], row[1])
def get_peer_by_username(self, val):
cursor = self._conn.cursor()
cursor.execute('select id, hash from peers_cache where username = ?', (val,))
row = cursor.fetchone()
if not row:
raise KeyError(val)
return utils.get_input_peer(row[0], row[1])
def get_peer_by_phone(self, val):
cursor = self._conn.cursor()
cursor.execute('select id, hash from peers_cache where phone = ?', (val,))
row = cursor.fetchone()
if not row:
raise KeyError(val)
return utils.get_input_peer(row[0], row[1])
def save(self, sync=False):
log.info('Committing SQLite session')
with self._lock:
self._conn.execute('delete from sessions')
self._conn.execute('insert into sessions values (?, ?, ?, ?, ?, ?)',
(self._dc_id, self._test_mode, self._auth_key, self._user_id, self._date, self._is_bot))
self._conn.commit()

View File

@ -1,46 +0,0 @@
import base64
import binascii
import struct
import pyrogram
from . import MemorySessionStorage, SessionDoesNotExist
class StringSessionStorage(MemorySessionStorage):
"""
Packs session data as following (forcing little-endian byte order):
Char dc_id (1 byte, unsigned)
Boolean test_mode (1 byte)
Long long user_id (8 bytes, signed)
Boolean is_bot (1 byte)
Bytes auth_key (256 bytes)
Uses Base64 encoding for printable representation
"""
PACK_FORMAT = '<B?q?256s'
def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_string: str):
super(StringSessionStorage, self).__init__(client)
self._session_string = session_string
def _unpack(self, data):
return struct.unpack(self.PACK_FORMAT, data)
def _pack(self):
return struct.pack(self.PACK_FORMAT, self._dc_id, self._test_mode, self._user_id, self._is_bot, self._auth_key)
def load(self):
try:
session_string = self._session_string[1:]
session_string += '=' * (4 - len(session_string) % 4) # restore padding
decoded = base64.b64decode(session_string, b'-_')
self._dc_id, self._test_mode, self._user_id, self._is_bot, self._auth_key = self._unpack(decoded)
except (struct.error, binascii.Error):
raise SessionDoesNotExist()
def save(self, sync=False):
if not sync:
packed = self._pack()
encoded = ':' + base64.b64encode(packed, b'-_').decode('latin-1').rstrip('=')
split = '\n'.join(['"{}"'.format(encoded[i: i + 50]) for i in range(0, len(encoded), 50)])
print('Created session string:\n{}'.format(split))

View File

@ -16,8 +16,6 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
from .abstract import SessionStorage, SessionDoesNotExist
from .memory import MemorySessionStorage
from .json import JsonSessionStorage
from .string import StringSessionStorage
from .sqlite import SQLiteSessionStorage
from .memory_storage import MemoryStorage
from .file_storage import FileStorage
from .storage import Storage

View File

@ -31,16 +31,14 @@ from pyrogram.api.types import (
)
from pyrogram.errors import PeerIdInvalid
from . import utils
from ..session_storage import SessionStorage
class HTML:
HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)</\1>")
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
def __init__(self, session_storage: SessionStorage, client: "pyrogram.BaseClient" = None):
def __init__(self, client: "pyrogram.BaseClient" = None):
self.client = client
self.session_storage = session_storage
def parse(self, message: str):
entities = []
@ -56,9 +54,10 @@ class HTML:
if mention:
user_id = int(mention.group(1))
try:
input_user = self.session_storage.get_peer_by_id(user_id)
except KeyError:
input_user = self.client.resolve_peer(user_id)
except PeerIdInvalid:
input_user = None
entity = (

View File

@ -31,7 +31,6 @@ from pyrogram.api.types import (
)
from pyrogram.errors import PeerIdInvalid
from . import utils
from ..session_storage import SessionStorage
class Markdown:
@ -55,9 +54,8 @@ class Markdown:
))
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
def __init__(self, session_storage: SessionStorage, client: "pyrogram.BaseClient" = None):
def __init__(self, client: "pyrogram.BaseClient" = None):
self.client = client
self.session_storage = session_storage
def parse(self, message: str):
message = utils.add_surrogates(str(message or "")).strip()
@ -73,9 +71,10 @@ class Markdown:
if mention:
user_id = int(mention.group(1))
try:
input_user = self.session_storage.get_peer_by_id(user_id)
except KeyError:
input_user = self.client.resolve_peer(user_id)
except PeerIdInvalid:
input_user = None
entity = (

View File

@ -22,10 +22,12 @@ from hashlib import sha1
from io import BytesIO
from os import urandom
import pyrogram
from pyrogram.api import functions, types
from pyrogram.api.core import TLObject, Long, Int
from pyrogram.connection import Connection
from pyrogram.crypto import AES, RSA, Prime
from .internals import MsgId
log = logging.getLogger(__name__)
@ -34,11 +36,11 @@ log = logging.getLogger(__name__)
class Auth:
MAX_RETRIES = 5
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict):
def __init__(self, client: "pyrogram.Client", dc_id: int):
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = ipv6
self.proxy = proxy
self.test_mode = client.storage.test_mode
self.ipv6 = client.ipv6
self.proxy = client.proxy
self.connection = None

View File

@ -34,6 +34,7 @@ from pyrogram.api.core import Message, TLObject, MsgContainer, Long, FutureSalt,
from pyrogram.connection import Connection
from pyrogram.crypto import AES, KDF
from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated
from .internals import MsgId, MsgFactory
log = logging.getLogger(__name__)
@ -70,12 +71,14 @@ class Session:
64: "[64] invalid container"
}
def __init__(self,
client: pyrogram,
dc_id: int,
auth_key: bytes,
is_media: bool = False,
is_cdn: bool = False):
def __init__(
self,
client: pyrogram,
dc_id: int,
auth_key: bytes,
is_media: bool = False,
is_cdn: bool = False
):
if not Session.notice_displayed:
print("Pyrogram v{}, {}".format(__version__, __copyright__))
print("Licensed under the terms of the " + __license__, end="\n\n")
@ -113,8 +116,12 @@ class Session:
def start(self):
while True:
self.connection = Connection(self.dc_id, self.client.session_storage.test_mode,
self.client.ipv6, self.client.proxy)
self.connection = Connection(
self.dc_id,
self.client.storage.test_mode,
self.client.ipv6,
self.client.proxy
)
try:
self.connection.connect()