From 682591ea8fdb3f33c2970690b0aafd67c8823c29 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Wed, 19 Jun 2019 16:01:23 +0200 Subject: [PATCH] Update Auth and Session to accommodate Storage Engines --- pyrogram/client/client.py | 24 +-- pyrogram/client/ext/base_client.py | 1 - pyrogram/client/session_storage/abstract.py | 139 ---------------- pyrogram/client/session_storage/json.py | 63 -------- pyrogram/client/session_storage/memory.py | 115 ------------- .../client/session_storage/sqlite/0001.sql | 24 --- .../client/session_storage/sqlite/__init__.py | 153 ------------------ pyrogram/client/session_storage/string.py | 46 ------ .../{session_storage => storage}/__init__.py | 8 +- pyrogram/client/style/html.py | 9 +- pyrogram/client/style/markdown.py | 9 +- pyrogram/session/auth.py | 10 +- pyrogram/session/session.py | 23 ++- 13 files changed, 36 insertions(+), 588 deletions(-) delete mode 100644 pyrogram/client/session_storage/abstract.py delete mode 100644 pyrogram/client/session_storage/json.py delete mode 100644 pyrogram/client/session_storage/memory.py delete mode 100644 pyrogram/client/session_storage/sqlite/0001.sql delete mode 100644 pyrogram/client/session_storage/sqlite/__init__.py delete mode 100644 pyrogram/client/session_storage/string.py rename pyrogram/client/{session_storage => storage}/__init__.py (78%) diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 1aa436b5..885c3334 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -26,14 +26,13 @@ import shutil import tempfile import threading import time -import warnings from configparser import ConfigParser from hashlib import sha256, md5 from importlib import import_module from pathlib import Path from signal import signal, SIGINT, SIGTERM, SIGABRT from threading import Thread -from typing import Union, List, Type +from typing import Union, List from pyrogram.api import functions, types from pyrogram.api.core import TLObject @@ -205,24 +204,9 @@ class Client(Methods, BaseClient): no_updates: bool = None, takeout: bool = None ): + super().__init__() - if isinstance(session_name, str): - if session_name == ':memory:': - session_storage = MemorySessionStorage(self) - elif session_name.startswith(':'): - session_storage = StringSessionStorage(self, session_name) - else: - session_storage = SQLiteSessionStorage(self, session_name) - elif isinstance(session_name, SessionStorage): - session_storage = session_name - else: - raise RuntimeError('Wrong session_name passed, expected str or SessionConfig subclass') - - super().__init__(session_storage) - - super().__init__(session_storage) - - self.session_name = str(session_name) # TODO: build correct session name + self.session_name = session_name self.api_id = int(api_id) if api_id else None self.api_hash = api_hash self.app_version = app_version @@ -232,7 +216,7 @@ class Client(Methods, BaseClient): self.ipv6 = ipv6 # TODO: Make code consistent, use underscore for private/protected fields self._proxy = proxy - self.session_storage.test_mode = test_mode + self.test_mode = test_mode self.bot_token = bot_token self.phone_number = phone_number self.phone_code = phone_code diff --git a/pyrogram/client/ext/base_client.py b/pyrogram/client/ext/base_client.py index aaf87823..9276b0eb 100644 --- a/pyrogram/client/ext/base_client.py +++ b/pyrogram/client/ext/base_client.py @@ -27,7 +27,6 @@ from threading import Lock from pyrogram import __version__ from ..style import Markdown, HTML from ...session.internals import MsgId -from ..session_storage import SessionStorage class BaseClient: diff --git a/pyrogram/client/session_storage/abstract.py b/pyrogram/client/session_storage/abstract.py deleted file mode 100644 index 134d5c8c..00000000 --- a/pyrogram/client/session_storage/abstract.py +++ /dev/null @@ -1,139 +0,0 @@ -# Pyrogram - Telegram MTProto API Client Library for Python -# Copyright (C) 2017-2019 Dan Tès -# -# This file is part of Pyrogram. -# -# Pyrogram is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published -# by the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Pyrogram is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with Pyrogram. If not, see . - -import abc -from typing import Type, Union - -import pyrogram -from pyrogram.api import types - - -class SessionDoesNotExist(Exception): - pass - - -class SessionStorage(abc.ABC): - def __init__(self, client: 'pyrogram.client.BaseClient'): - self._client = client - - @abc.abstractmethod - def load(self): - ... - - @abc.abstractmethod - def save(self, sync=False): - ... - - @property - @abc.abstractmethod - def dc_id(self): - ... - - @dc_id.setter - @abc.abstractmethod - def dc_id(self, val): - ... - - @property - @abc.abstractmethod - def test_mode(self): - ... - - @test_mode.setter - @abc.abstractmethod - def test_mode(self, val): - ... - - @property - @abc.abstractmethod - def auth_key(self): - ... - - @auth_key.setter - @abc.abstractmethod - def auth_key(self, val): - ... - - @property - @abc.abstractmethod - def user_id(self): - ... - - @user_id.setter - @abc.abstractmethod - def user_id(self, val): - ... - - @property - @abc.abstractmethod - def date(self): - ... - - @date.setter - @abc.abstractmethod - def date(self, val): - ... - - @property - @abc.abstractmethod - def is_bot(self): - ... - - @is_bot.setter - @abc.abstractmethod - def is_bot(self, val): - ... - - @abc.abstractmethod - def clear_cache(self): - ... - - @abc.abstractmethod - def cache_peer(self, entity: Union[types.User, - types.Chat, types.ChatForbidden, - types.Channel, types.ChannelForbidden]): - ... - - @abc.abstractmethod - def get_peer_by_id(self, val: int): - ... - - @abc.abstractmethod - def get_peer_by_username(self, val: str): - ... - - @abc.abstractmethod - def get_peer_by_phone(self, val: str): - ... - - def get_peer(self, peer_id: Union[int, str]): - if isinstance(peer_id, int): - return self.get_peer_by_id(peer_id) - else: - peer_id = peer_id.lstrip('+@') - if peer_id.isdigit(): - return self.get_peer_by_phone(peer_id) - return self.get_peer_by_username(peer_id) - - @abc.abstractmethod - def peers_count(self): - ... - - @abc.abstractmethod - def contacts_count(self): - ... diff --git a/pyrogram/client/session_storage/json.py b/pyrogram/client/session_storage/json.py deleted file mode 100644 index 4a48d3c1..00000000 --- a/pyrogram/client/session_storage/json.py +++ /dev/null @@ -1,63 +0,0 @@ -# Pyrogram - Telegram MTProto API Client Library for Python -# Copyright (C) 2017-2019 Dan Tès -# -# This file is part of Pyrogram. -# -# Pyrogram is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published -# by the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Pyrogram is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with Pyrogram. If not, see . - -import base64 -import json -import logging -import os -import shutil - -import pyrogram -from ..ext import utils -from . import MemorySessionStorage, SessionDoesNotExist - - -log = logging.getLogger(__name__) - -EXTENSION = '.session' - - -class JsonSessionStorage(MemorySessionStorage): - def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_name: str): - super(JsonSessionStorage, self).__init__(client) - self._session_name = session_name - - def _get_file_name(self, name: str): - if not name.endswith(EXTENSION): - name += EXTENSION - return os.path.join(self._client.workdir, name) - - def load(self): - file_path = self._get_file_name(self._session_name) - log.info('Loading JSON session from {}'.format(file_path)) - - try: - with open(file_path, encoding='utf-8') as f: - s = json.load(f) - except FileNotFoundError: - raise SessionDoesNotExist() - - self._dc_id = s["dc_id"] - self._test_mode = s["test_mode"] - self._auth_key = base64.b64decode("".join(s["auth_key"])) # join split key - self._user_id = s["user_id"] - self._date = s.get("date", 0) - self._is_bot = s.get('is_bot', self._is_bot) - - def save(self, sync=False): - pass diff --git a/pyrogram/client/session_storage/memory.py b/pyrogram/client/session_storage/memory.py deleted file mode 100644 index c0610e70..00000000 --- a/pyrogram/client/session_storage/memory.py +++ /dev/null @@ -1,115 +0,0 @@ -import pyrogram -from pyrogram.api import types -from . import SessionStorage, SessionDoesNotExist - - -class MemorySessionStorage(SessionStorage): - def __init__(self, client: 'pyrogram.client.ext.BaseClient'): - super(MemorySessionStorage, self).__init__(client) - self._dc_id = 1 - self._test_mode = None - self._auth_key = None - self._user_id = None - self._date = 0 - self._is_bot = False - self._peers_cache = {} - - def load(self): - raise SessionDoesNotExist() - - def save(self, sync=False): - pass - - @property - def dc_id(self): - return self._dc_id - - @dc_id.setter - def dc_id(self, val): - self._dc_id = val - - @property - def test_mode(self): - return self._test_mode - - @test_mode.setter - def test_mode(self, val): - self._test_mode = val - - @property - def auth_key(self): - return self._auth_key - - @auth_key.setter - def auth_key(self, val): - self._auth_key = val - - @property - def user_id(self): - return self._user_id - - @user_id.setter - def user_id(self, val): - self._user_id = val - - @property - def date(self): - return self._date - - @date.setter - def date(self, val): - self._date = val - - @property - def is_bot(self): - return self._is_bot - - @is_bot.setter - def is_bot(self, val): - self._is_bot = val - - def clear_cache(self): - keys = list(filter(lambda k: k[0] in 'up', self._peers_cache.keys())) - for key in keys: - try: - del self._peers_cache[key] - except KeyError: - pass - - def cache_peer(self, entity): - if isinstance(entity, types.User): - input_peer = types.InputPeerUser( - user_id=entity.id, - access_hash=entity.access_hash - ) - self._peers_cache['i' + str(entity.id)] = input_peer - if entity.username: - self._peers_cache['u' + entity.username.lower()] = input_peer - if entity.phone: - self._peers_cache['p' + entity.phone] = input_peer - elif isinstance(entity, (types.Chat, types.ChatForbidden)): - self._peers_cache['i-' + str(entity.id)] = types.InputPeerChat(chat_id=entity.id) - elif isinstance(entity, (types.Channel, types.ChannelForbidden)): - input_peer = types.InputPeerChannel( - channel_id=entity.id, - access_hash=entity.access_hash - ) - self._peers_cache['i-100' + str(entity.id)] = input_peer - username = getattr(entity, "username", None) - if username: - self._peers_cache['u' + username.lower()] = input_peer - - def get_peer_by_id(self, val): - return self._peers_cache['i' + str(val)] - - def get_peer_by_username(self, val): - return self._peers_cache['u' + val.lower()] - - def get_peer_by_phone(self, val): - return self._peers_cache['p' + val] - - def peers_count(self): - return len(list(filter(lambda k: k[0] == 'i', self._peers_cache.keys()))) - - def contacts_count(self): - return len(list(filter(lambda k: k[0] == 'p', self._peers_cache.keys()))) diff --git a/pyrogram/client/session_storage/sqlite/0001.sql b/pyrogram/client/session_storage/sqlite/0001.sql deleted file mode 100644 index c6c51d24..00000000 --- a/pyrogram/client/session_storage/sqlite/0001.sql +++ /dev/null @@ -1,24 +0,0 @@ -create table sessions ( - dc_id integer primary key, - test_mode integer, - auth_key blob, - user_id integer, - date integer, - is_bot integer -); - -create table peers_cache ( - id integer primary key, - hash integer, - username text, - phone integer -); - -create table migrations ( - name text primary key -); - -create index username_idx on peers_cache(username); -create index phone_idx on peers_cache(phone); - -insert into migrations (name) values ('0001'); diff --git a/pyrogram/client/session_storage/sqlite/__init__.py b/pyrogram/client/session_storage/sqlite/__init__.py deleted file mode 100644 index a16e75e8..00000000 --- a/pyrogram/client/session_storage/sqlite/__init__.py +++ /dev/null @@ -1,153 +0,0 @@ -# Pyrogram - Telegram MTProto API Client Library for Python -# Copyright (C) 2017-2019 Dan Tès -# -# This file is part of Pyrogram. -# -# Pyrogram is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published -# by the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Pyrogram is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with Pyrogram. If not, see . - -import logging -import os -import shutil -import sqlite3 -from threading import Lock - -import pyrogram -from ....api import types -from ...ext import utils -from .. import MemorySessionStorage, SessionDoesNotExist, JsonSessionStorage - - -log = logging.getLogger(__name__) - -EXTENSION = '.session' -MIGRATIONS = ['0001'] - - -class SQLiteSessionStorage(MemorySessionStorage): - def __init__(self, client: 'pyrogram.client.ext.BaseClient', session_name: str): - super(SQLiteSessionStorage, self).__init__(client) - self._session_name = session_name - self._conn = None # type: sqlite3.Connection - self._lock = Lock() - - def _get_file_name(self, name: str): - if not name.endswith(EXTENSION): - name += EXTENSION - return os.path.join(self._client.workdir, name) - - def _apply_migrations(self, new_db=False): - self._conn.execute('PRAGMA read_uncommitted = true') - migrations = MIGRATIONS.copy() - if not new_db: - cursor = self._conn.cursor() - cursor.execute('select name from migrations') - for row in cursor.fetchone(): - migrations.remove(row) - for name in migrations: - with open(os.path.join(os.path.dirname(__file__), '{}.sql'.format(name))) as script: - self._conn.executescript(script.read()) - - def _migrate_from_json(self): - jss = JsonSessionStorage(self._client, self._session_name) - jss.load() - file_path = self._get_file_name(self._session_name) - self._conn = sqlite3.connect(file_path + '.tmp') - self._apply_migrations(new_db=True) - self._dc_id, self._test_mode, self._auth_key, self._user_id, self._date, self._is_bot = \ - jss.dc_id, jss.test_mode, jss.auth_key, jss.user_id, jss.date, jss.is_bot - self.save() - self._conn.close() - shutil.move(file_path + '.tmp', file_path) - log.warning('Session was migrated from JSON, loading...') - self.load() - - def load(self): - file_path = self._get_file_name(self._session_name) - log.info('Loading SQLite session from {}'.format(file_path)) - - if os.path.isfile(file_path): - try: - self._conn = sqlite3.connect(file_path, isolation_level='EXCLUSIVE', check_same_thread=False) - self._apply_migrations() - except sqlite3.DatabaseError: - log.warning('Trying to migrate session from JSON...') - self._migrate_from_json() - return - else: - self._conn = sqlite3.connect(file_path, isolation_level='EXCLUSIVE', check_same_thread=False) - self._apply_migrations(new_db=True) - - cursor = self._conn.cursor() - cursor.execute('select dc_id, test_mode, auth_key, user_id, "date", is_bot from sessions') - row = cursor.fetchone() - if not row: - raise SessionDoesNotExist() - - self._dc_id = row[0] - self._test_mode = bool(row[1]) - self._auth_key = row[2] - self._user_id = row[3] - self._date = row[4] - self._is_bot = bool(row[5]) - - def cache_peer(self, entity): - peer_id = username = phone = access_hash = None - - if isinstance(entity, types.User): - peer_id = entity.id - username = entity.username.lower() if entity.username else None - phone = entity.phone or None - access_hash = entity.access_hash - elif isinstance(entity, (types.Chat, types.ChatForbidden)): - peer_id = -entity.id - elif isinstance(entity, (types.Channel, types.ChannelForbidden)): - peer_id = int('-100' + str(entity.id)) - username = entity.username.lower() if hasattr(entity, 'username') and entity.username else None - access_hash = entity.access_hash - - with self._lock: - self._conn.execute('insert or replace into peers_cache values (?, ?, ?, ?)', - (peer_id, access_hash, username, phone)) - - def get_peer_by_id(self, val): - cursor = self._conn.cursor() - cursor.execute('select id, hash from peers_cache where id = ?', (val,)) - row = cursor.fetchone() - if not row: - raise KeyError(val) - return utils.get_input_peer(row[0], row[1]) - - def get_peer_by_username(self, val): - cursor = self._conn.cursor() - cursor.execute('select id, hash from peers_cache where username = ?', (val,)) - row = cursor.fetchone() - if not row: - raise KeyError(val) - return utils.get_input_peer(row[0], row[1]) - - def get_peer_by_phone(self, val): - cursor = self._conn.cursor() - cursor.execute('select id, hash from peers_cache where phone = ?', (val,)) - row = cursor.fetchone() - if not row: - raise KeyError(val) - return utils.get_input_peer(row[0], row[1]) - - def save(self, sync=False): - log.info('Committing SQLite session') - with self._lock: - self._conn.execute('delete from sessions') - self._conn.execute('insert into sessions values (?, ?, ?, ?, ?, ?)', - (self._dc_id, self._test_mode, self._auth_key, self._user_id, self._date, self._is_bot)) - self._conn.commit() diff --git a/pyrogram/client/session_storage/string.py b/pyrogram/client/session_storage/string.py deleted file mode 100644 index 11051323..00000000 --- a/pyrogram/client/session_storage/string.py +++ /dev/null @@ -1,46 +0,0 @@ -import base64 -import binascii -import struct - -import pyrogram -from . import MemorySessionStorage, SessionDoesNotExist - - -class StringSessionStorage(MemorySessionStorage): - """ - Packs session data as following (forcing little-endian byte order): - Char dc_id (1 byte, unsigned) - Boolean test_mode (1 byte) - Long long user_id (8 bytes, signed) - Boolean is_bot (1 byte) - Bytes auth_key (256 bytes) - - Uses Base64 encoding for printable representation - """ - PACK_FORMAT = '. -from .abstract import SessionStorage, SessionDoesNotExist -from .memory import MemorySessionStorage -from .json import JsonSessionStorage -from .string import StringSessionStorage -from .sqlite import SQLiteSessionStorage +from .memory_storage import MemoryStorage +from .file_storage import FileStorage +from .storage import Storage diff --git a/pyrogram/client/style/html.py b/pyrogram/client/style/html.py index 894dbd6c..9c0a372c 100644 --- a/pyrogram/client/style/html.py +++ b/pyrogram/client/style/html.py @@ -31,16 +31,14 @@ from pyrogram.api.types import ( ) from pyrogram.errors import PeerIdInvalid from . import utils -from ..session_storage import SessionStorage class HTML: HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)") MENTION_RE = re.compile(r"tg://user\?id=(\d+)") - def __init__(self, session_storage: SessionStorage, client: "pyrogram.BaseClient" = None): + def __init__(self, client: "pyrogram.BaseClient" = None): self.client = client - self.session_storage = session_storage def parse(self, message: str): entities = [] @@ -56,9 +54,10 @@ class HTML: if mention: user_id = int(mention.group(1)) + try: - input_user = self.session_storage.get_peer_by_id(user_id) - except KeyError: + input_user = self.client.resolve_peer(user_id) + except PeerIdInvalid: input_user = None entity = ( diff --git a/pyrogram/client/style/markdown.py b/pyrogram/client/style/markdown.py index 68b54bbb..adb86e94 100644 --- a/pyrogram/client/style/markdown.py +++ b/pyrogram/client/style/markdown.py @@ -31,7 +31,6 @@ from pyrogram.api.types import ( ) from pyrogram.errors import PeerIdInvalid from . import utils -from ..session_storage import SessionStorage class Markdown: @@ -55,9 +54,8 @@ class Markdown: )) MENTION_RE = re.compile(r"tg://user\?id=(\d+)") - def __init__(self, session_storage: SessionStorage, client: "pyrogram.BaseClient" = None): + def __init__(self, client: "pyrogram.BaseClient" = None): self.client = client - self.session_storage = session_storage def parse(self, message: str): message = utils.add_surrogates(str(message or "")).strip() @@ -73,9 +71,10 @@ class Markdown: if mention: user_id = int(mention.group(1)) + try: - input_user = self.session_storage.get_peer_by_id(user_id) - except KeyError: + input_user = self.client.resolve_peer(user_id) + except PeerIdInvalid: input_user = None entity = ( diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index fb6e7ca3..b05b2855 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -22,10 +22,12 @@ from hashlib import sha1 from io import BytesIO from os import urandom +import pyrogram from pyrogram.api import functions, types from pyrogram.api.core import TLObject, Long, Int from pyrogram.connection import Connection from pyrogram.crypto import AES, RSA, Prime + from .internals import MsgId log = logging.getLogger(__name__) @@ -34,11 +36,11 @@ log = logging.getLogger(__name__) class Auth: MAX_RETRIES = 5 - def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, proxy: dict): + def __init__(self, client: "pyrogram.Client", dc_id: int): self.dc_id = dc_id - self.test_mode = test_mode - self.ipv6 = ipv6 - self.proxy = proxy + self.test_mode = client.storage.test_mode + self.ipv6 = client.ipv6 + self.proxy = client.proxy self.connection = None diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index bd7f0f26..5947fc0f 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -34,6 +34,7 @@ from pyrogram.api.core import Message, TLObject, MsgContainer, Long, FutureSalt, from pyrogram.connection import Connection from pyrogram.crypto import AES, KDF from pyrogram.errors import RPCError, InternalServerError, AuthKeyDuplicated + from .internals import MsgId, MsgFactory log = logging.getLogger(__name__) @@ -70,12 +71,14 @@ class Session: 64: "[64] invalid container" } - def __init__(self, - client: pyrogram, - dc_id: int, - auth_key: bytes, - is_media: bool = False, - is_cdn: bool = False): + def __init__( + self, + client: pyrogram, + dc_id: int, + auth_key: bytes, + is_media: bool = False, + is_cdn: bool = False + ): if not Session.notice_displayed: print("Pyrogram v{}, {}".format(__version__, __copyright__)) print("Licensed under the terms of the " + __license__, end="\n\n") @@ -113,8 +116,12 @@ class Session: def start(self): while True: - self.connection = Connection(self.dc_id, self.client.session_storage.test_mode, - self.client.ipv6, self.client.proxy) + self.connection = Connection( + self.dc_id, + self.client.storage.test_mode, + self.client.ipv6, + self.client.proxy + ) try: self.connection.connect()