From 8cc61f00ed74fc8290b4d75cc1503275a42d5136 Mon Sep 17 00:00:00 2001 From: bakatrouble Date: Fri, 1 Mar 2019 21:23:01 +0300 Subject: [PATCH] Fix threading with sqlite storage --- .../client/session_storage/sqlite/__init__.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pyrogram/client/session_storage/sqlite/__init__.py b/pyrogram/client/session_storage/sqlite/__init__.py index 0308a4dc..a16e75e8 100644 --- a/pyrogram/client/session_storage/sqlite/__init__.py +++ b/pyrogram/client/session_storage/sqlite/__init__.py @@ -20,6 +20,7 @@ import logging import os import shutil import sqlite3 +from threading import Lock import pyrogram from ....api import types @@ -38,6 +39,7 @@ class SQLiteSessionStorage(MemorySessionStorage): super(SQLiteSessionStorage, self).__init__(client) self._session_name = session_name self._conn = None # type: sqlite3.Connection + self._lock = Lock() def _get_file_name(self, name: str): if not name.endswith(EXTENSION): @@ -45,6 +47,7 @@ class SQLiteSessionStorage(MemorySessionStorage): return os.path.join(self._client.workdir, name) def _apply_migrations(self, new_db=False): + self._conn.execute('PRAGMA read_uncommitted = true') migrations = MIGRATIONS.copy() if not new_db: cursor = self._conn.cursor() @@ -75,14 +78,14 @@ class SQLiteSessionStorage(MemorySessionStorage): if os.path.isfile(file_path): try: - self._conn = sqlite3.connect(file_path) + self._conn = sqlite3.connect(file_path, isolation_level='EXCLUSIVE', check_same_thread=False) self._apply_migrations() except sqlite3.DatabaseError: log.warning('Trying to migrate session from JSON...') self._migrate_from_json() return else: - self._conn = sqlite3.connect(file_path) + self._conn = sqlite3.connect(file_path, isolation_level='EXCLUSIVE', check_same_thread=False) self._apply_migrations(new_db=True) cursor = self._conn.cursor() @@ -113,8 +116,9 @@ class SQLiteSessionStorage(MemorySessionStorage): username = entity.username.lower() if hasattr(entity, 'username') and entity.username else None access_hash = entity.access_hash - self._conn.execute('insert or replace into peers_cache values (?, ?, ?, ?)', - (peer_id, access_hash, username, phone)) + with self._lock: + self._conn.execute('insert or replace into peers_cache values (?, ?, ?, ?)', + (peer_id, access_hash, username, phone)) def get_peer_by_id(self, val): cursor = self._conn.cursor() @@ -142,7 +146,8 @@ class SQLiteSessionStorage(MemorySessionStorage): def save(self, sync=False): log.info('Committing SQLite session') - self._conn.execute('delete from sessions') - self._conn.execute('insert into sessions values (?, ?, ?, ?, ?, ?)', - (self._dc_id, self._test_mode, self._auth_key, self._user_id, self._date, self._is_bot)) - self._conn.commit() + with self._lock: + self._conn.execute('delete from sessions') + self._conn.execute('insert into sessions values (?, ?, ?, ?, ?, ?)', + (self._dc_id, self._test_mode, self._auth_key, self._user_id, self._date, self._is_bot)) + self._conn.commit()