From 8c7793b91a2a182ec7d7831f2c03283dad3488cc Mon Sep 17 00:00:00 2001 From: madt1m Date: Tue, 24 Jul 2018 15:57:11 +0200 Subject: [PATCH] session: temporary DB is now stored in temporary dir --- mitmproxy/addons/session.py | 14 ++++++-------- test/mitmproxy/addons/test_session.py | 8 ++++++-- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index 7f1c0025d..c49b95c42 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -1,4 +1,5 @@ import tempfile +import shutil import sqlite3 import os @@ -20,7 +21,7 @@ class SessionDB: or create a new one with optional path. :param db_path: """ - self.temp = None + self.tempdir = None self.con = None if db_path is not None and os.path.isfile(db_path): self._load_session(db_path) @@ -28,19 +29,16 @@ class SessionDB: if db_path: path = db_path else: - # We use tempfile only to generate a path, since we demand file creation to sqlite, and removal to os. - self.temp = tempfile.NamedTemporaryFile() - path = self.temp.name - self.temp.close() + self.tempdir = tempfile.mkdtemp() + path = os.path.join(self.tempdir, 'tmp.sqlite') self.con = sqlite3.connect(path) self._create_session() def __del__(self): if self.con: self.con.close() - if self.temp: - # This is a workaround to ensure portability - os.remove(self.temp.name) + if self.tempdir: + shutil.rmtree(self.tempdir) def _load_session(self, path): if not self.is_session_db(path): diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py index cb36e2830..d4b1109b0 100644 --- a/test/mitmproxy/addons/test_session.py +++ b/test/mitmproxy/addons/test_session.py @@ -8,10 +8,14 @@ from mitmproxy.utils.data import pkg_data class TestSession: - def test_session_temporary(self, tdata): + def test_session_temporary(self): s = session.SessionDB() - filename = s.temp.name + td = s.tempdir + filename = os.path.join(td, 'tmp.sqlite') assert session.SessionDB.is_session_db(filename) + assert os.path.isdir(td) + del s + assert not os.path.isdir(td) def test_session_not_valid(self, tdata): path = tdata.path('mitmproxy/data/') + '/test_snv.sqlite'