diff --git a/pyrogram/client/storage/file_storage.py b/pyrogram/client/storage/file_storage.py index 4ee81b5b..c44c78dd 100644 --- a/pyrogram/client/storage/file_storage.py +++ b/pyrogram/client/storage/file_storage.py @@ -22,34 +22,29 @@ import logging import os import sqlite3 from pathlib import Path -from threading import Lock -from .memory_storage import MemoryStorage +from .sqlite_storage import SQLiteStorage log = logging.getLogger(__name__) -class FileStorage(MemoryStorage): +class FileStorage(SQLiteStorage): FILE_EXTENSION = ".session" def __init__(self, name: str, workdir: Path): super().__init__(name) - self.workdir = workdir self.database = workdir / (self.name + self.FILE_EXTENSION) - self.conn = None # type: sqlite3.Connection - self.lock = Lock() - # noinspection PyAttributeOutsideInit def migrate_from_json(self, session_json: dict): self.open() - self.dc_id = session_json["dc_id"] - self.test_mode = session_json["test_mode"] - self.auth_key = base64.b64decode("".join(session_json["auth_key"])) - self.user_id = session_json["user_id"] - self.date = session_json.get("date", 0) - self.is_bot = session_json.get("is_bot", False) + self.dc_id(session_json["dc_id"]) + self.test_mode(session_json["test_mode"]) + self.auth_key(base64.b64decode("".join(session_json["auth_key"]))) + self.user_id(session_json["user_id"]) + self.date(session_json.get("date", 0)) + self.is_bot(session_json.get("is_bot", False)) peers_by_id = session_json.get("peers_by_id", {}) peers_by_phone = session_json.get("peers_by_phone", {}) @@ -98,11 +93,7 @@ class FileStorage(MemoryStorage): if Path(path.name + ".OLD").is_file(): log.warning('Old session file detected: "{}.OLD". You can remove this file now'.format(path.name)) - self.conn = sqlite3.connect( - str(path), - timeout=1, - check_same_thread=False - ) + self.conn = sqlite3.connect(str(path), timeout=1, check_same_thread=False) if not file_exists: self.create() diff --git a/pyrogram/client/storage/memory_storage.py b/pyrogram/client/storage/memory_storage.py index b24fce38..00b81e7a 100644 --- a/pyrogram/client/storage/memory_storage.py +++ b/pyrogram/client/storage/memory_storage.py @@ -17,226 +17,37 @@ # along with Pyrogram. If not, see . import base64 -import inspect import logging import sqlite3 import struct -import time -from pathlib import Path -from threading import Lock -from typing import List, Tuple -from pyrogram.api import types -from pyrogram.client.storage.storage import Storage +from .sqlite_storage import SQLiteStorage log = logging.getLogger(__name__) -class MemoryStorage(Storage): - SCHEMA_VERSION = 1 - USERNAME_TTL = 8 * 60 * 60 - SESSION_STRING_FMT = ">B?256sI?" - SESSION_STRING_SIZE = 351 - +class MemoryStorage(SQLiteStorage): def __init__(self, name: str): super().__init__(name) - self.conn = None # type: sqlite3.Connection - self.lock = Lock() - - def create(self): - with self.lock, self.conn: - with open(str(Path(__file__).parent / "schema.sql"), "r") as schema: - self.conn.executescript(schema.read()) - - self.conn.execute( - "INSERT INTO version VALUES (?)", - (self.SCHEMA_VERSION,) - ) - - self.conn.execute( - "INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?)", - (1, None, None, 0, None, None) - ) - - def _import_session_string(self, session_string: str): - decoded = base64.urlsafe_b64decode(session_string + "=" * (-len(session_string) % 4)) - return struct.unpack(self.SESSION_STRING_FMT, decoded) - - def export_session_string(self): - packed = struct.pack( - self.SESSION_STRING_FMT, - self.dc_id, - self.test_mode, - self.auth_key, - self.user_id, - self.is_bot - ) - - return base64.urlsafe_b64encode(packed).decode().rstrip("=") - - # noinspection PyAttributeOutsideInit def open(self): self.conn = sqlite3.connect(":memory:", check_same_thread=False) self.create() if self.name != ":memory:": - imported_session_string = self._import_session_string(self.name) + dc_id, test_mode, auth_key, user_id, is_bot = struct.unpack( + self.SESSION_STRING_FORMAT, + base64.urlsafe_b64decode( + self.name + "=" * (-len(self.name) % 4) + ) + ) - self.dc_id, self.test_mode, self.auth_key, self.user_id, self.is_bot = imported_session_string - self.date = 0 + self.dc_id(dc_id) + self.test_mode(test_mode) + self.auth_key(auth_key) + self.user_id(user_id) + self.is_bot(is_bot) + self.date(0) - # noinspection PyAttributeOutsideInit - def save(self): - self.date = int(time.time()) - - with self.lock: - self.conn.commit() - - def close(self): - with self.lock: - self.conn.close() - - def destroy(self): + def delete(self): pass - - def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): - with self.lock: - self.conn.executemany( - "REPLACE INTO peers (id, access_hash, type, username, phone_number)" - "VALUES (?, ?, ?, ?, ?)", - peers - ) - - def clear_peers(self): - with self.lock, self.conn: - self.conn.execute( - "DELETE FROM peers" - ) - - @staticmethod - def _get_input_peer(peer_id: int, access_hash: int, peer_type: str): - if peer_type in ["user", "bot"]: - return types.InputPeerUser( - user_id=peer_id, - access_hash=access_hash - ) - - if peer_type == "group": - return types.InputPeerChat( - chat_id=-peer_id - ) - - if peer_type in ["channel", "supergroup"]: - return types.InputPeerChannel( - channel_id=int(str(peer_id)[4:]), - access_hash=access_hash - ) - - raise ValueError("Invalid peer type: {}".format(peer_type)) - - def get_peer_by_id(self, peer_id: int): - r = self.conn.execute( - "SELECT id, access_hash, type FROM peers WHERE id = ?", - (peer_id,) - ).fetchone() - - if r is None: - raise KeyError("ID not found: {}".format(peer_id)) - - return self._get_input_peer(*r) - - def get_peer_by_username(self, username: str): - r = self.conn.execute( - "SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?", - (username,) - ).fetchone() - - if r is None: - raise KeyError("Username not found: {}".format(username)) - - if abs(time.time() - r[3]) > self.USERNAME_TTL: - raise KeyError("Username expired: {}".format(username)) - - return self._get_input_peer(*r[:3]) - - def get_peer_by_phone_number(self, phone_number: str): - r = self.conn.execute( - "SELECT id, access_hash, type FROM peers WHERE phone_number = ?", - (phone_number,) - ).fetchone() - - if r is None: - raise KeyError("Phone number not found: {}".format(phone_number)) - - return self._get_input_peer(*r) - - @property - def peers_count(self): - return self.conn.execute( - "SELECT COUNT(*) FROM peers" - ).fetchone()[0] - - def _get(self): - attr = inspect.stack()[1].function - - return self.conn.execute( - "SELECT {} FROM sessions".format(attr) - ).fetchone()[0] - - def _set(self, value): - attr = inspect.stack()[1].function - - with self.lock, self.conn: - self.conn.execute( - "UPDATE sessions SET {} = ?".format(attr), - (value,) - ) - - @property - def dc_id(self): - return self._get() - - @dc_id.setter - def dc_id(self, value): - self._set(value) - - @property - def test_mode(self): - return self._get() - - @test_mode.setter - def test_mode(self, value): - self._set(value) - - @property - def auth_key(self): - return self._get() - - @auth_key.setter - def auth_key(self, value): - self._set(value) - - @property - def date(self): - return self._get() - - @date.setter - def date(self, value): - self._set(value) - - @property - def user_id(self): - return self._get() - - @user_id.setter - def user_id(self, value): - self._set(value) - - @property - def is_bot(self): - return self._get() - - @is_bot.setter - def is_bot(self, value): - self._set(value)