Add workdir parameter

This makes possible to define a custom working directory.
A working directory is used to store session files. Defaults to "."
This commit is contained in:
Dan 2018-04-19 10:06:41 +02:00
parent 20ec707aa8
commit ae98732b95
2 changed files with 16 additions and 10 deletions

View File

@ -152,7 +152,8 @@ class Client:
force_sms: bool = False, force_sms: bool = False,
first_name: str = None, first_name: str = None,
last_name: str = None, last_name: str = None,
workers: int = 4): workers: int = 4,
workdir: str = "."):
self.session_name = session_name self.session_name = session_name
self.api_id = int(api_id) if api_id else None self.api_id = int(api_id) if api_id else None
self.api_hash = api_hash self.api_hash = api_hash
@ -167,6 +168,7 @@ class Client:
self.force_sms = force_sms self.force_sms = force_sms
self.workers = workers self.workers = workers
self.workdir = workdir
self.token = None self.token = None
@ -882,7 +884,7 @@ class Client:
def load_session(self): def load_session(self):
try: try:
with open("{}.session".format(self.session_name), encoding="utf-8") as f: with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), encoding="utf-8") as f:
s = json.load(f) s = json.load(f)
except FileNotFoundError: except FileNotFoundError:
self.dc_id = 1 self.dc_id = 1
@ -914,7 +916,9 @@ class Client:
auth_key = base64.b64encode(self.auth_key).decode() auth_key = base64.b64encode(self.auth_key).decode()
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)]
with open("{}.session".format(self.session_name), "w", encoding="utf-8") as f: os.makedirs(self.workdir, exist_ok=True)
with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), "w", encoding="utf-8") as f:
json.dump( json.dump(
dict( dict(
dc_id=self.dc_id, dc_id=self.dc_id,

View File

@ -23,6 +23,7 @@ import os
import shutil import shutil
import time import time
from threading import Thread, Event, Lock from threading import Thread, Event, Lock
from . import utils from . import utils
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -80,6 +81,9 @@ class Syncer:
@classmethod @classmethod
def sync(cls, client): def sync(cls, client):
temporary = os.path.join(client.workdir, "{}.sync".format(client.session_name))
persistent = os.path.join(client.workdir, "{}.session".format(client.session_name))
try: try:
auth_key = base64.b64encode(client.auth_key).decode() auth_key = base64.b64encode(client.auth_key).decode()
auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)] auth_key = [auth_key[i: i + 43] for i in range(0, len(auth_key), 43)]
@ -104,7 +108,9 @@ class Syncer:
} }
) )
with open("{}.sync".format(client.session_name), "w", encoding="utf-8") as f: os.makedirs(client.workdir, exist_ok=True)
with open(temporary, "w", encoding="utf-8") as f:
json.dump(data, f, indent=4) json.dump(data, f, indent=4)
f.flush() f.flush()
@ -112,14 +118,10 @@ class Syncer:
except Exception as e: except Exception as e:
log.critical(e, exc_info=True) log.critical(e, exc_info=True)
else: else:
shutil.move( shutil.move(temporary, persistent)
"{}.sync".format(client.session_name),
"{}.session".format(client.session_name)
)
log.info("Synced {}".format(client.session_name)) log.info("Synced {}".format(client.session_name))
finally: finally:
try: try:
os.remove("{}.sync".format(client.session_name)) os.remove(temporary)
except OSError: except OSError:
pass pass