diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py
index c08097ee8..f9d3af3fa 100644
--- a/mitmproxy/addons/session.py
+++ b/mitmproxy/addons/session.py
@@ -1,9 +1,11 @@
+import collections
 import tempfile
 import asyncio
 import typing
 import bisect
 import shutil
 import sqlite3
+import copy
 import os
 
 from mitmproxy import flowfilter
@@ -48,6 +50,7 @@ class SessionDB:
         or create a new one with optional path.
         :param db_path:
         """
+        self.live_components = {}
         self.tempdir = None
         self.con = None
         # This is used for fast look-ups over bodies already dumped to database.
@@ -71,11 +74,14 @@ class SessionDB:
             shutil.rmtree(self.tempdir)
 
     def __contains__(self, fid):
-        return fid in self._get_ids()
+        return any(fid == i for i in self._get_ids())
+
+    def __len__(self):
+        row = self.con.execute("SELECT COUNT(*) FROM flow;").fetchone()
+        return row[0] if row else 0
 
     def _get_ids(self):
-        with self.con as con:
-            return [t[0] for t in con.execute("SELECT id FROM flow;").fetchall()]
+        return [t[0] for t in self.con.execute("SELECT id FROM flow;").fetchall()]
 
     def _load_session(self, path):
         if not self.is_session_db(path):
@@ -85,8 +91,8 @@
     def _create_session(self):
         script_path = pkg_data.path("io/sql/session_create.sql")
         qry = open(script_path, 'r').read()
-        with self.con:
-            self.con.executescript(qry)
+        self.con.executescript(qry)
+        self.con.commit()
 
     @staticmethod
     def is_session_db(path):
@@ -110,62 +116,104 @@ class SessionDB:
             c.close()
         return False
 
+    def _disassemble(self, flow):
+        # Some live components of a flow cannot be serialized, but they are needed
+        # for correct functionality. We solve this by storing a tuple of those
+        # components per flow id, and restoring them when the flow is retrieved.
+        self.live_components[flow.id] = (
+            flow.client_conn.wfile,
+            flow.client_conn.rfile,
+            flow.client_conn.reply,
+            flow.server_conn.wfile,
+            flow.server_conn.rfile,
+            flow.server_conn.reply,
+            (flow.server_conn.via.wfile, flow.server_conn.via.rfile,
+             flow.server_conn.via.reply) if flow.server_conn.via else None,
+            flow.reply
+        )
+
+    def _reassemble(self, flow):
+        if flow.id in self.live_components:
+            cwf, crf, crp, swf, srf, srp, via, rep = self.live_components[flow.id]
+            flow.client_conn.wfile = cwf
+            flow.client_conn.rfile = crf
+            flow.client_conn.reply = crp
+            flow.server_conn.wfile = swf
+            flow.server_conn.rfile = srf
+            flow.server_conn.reply = srp
+            flow.reply = rep
+            if via:
+                # Restore in the same order the tuple was saved in _disassemble.
+                flow.server_conn.via.wfile, flow.server_conn.via.rfile, flow.server_conn.via.reply = via
+        return flow
+
     def store_flows(self, flows):
         body_buf = []
         flow_buf = []
         for flow in flows:
-            if len(flow.request.content) > self.content_threshold:
-                body_buf.append((flow.id, self.type_mappings["body"][1], flow.request.content))
-                flow.request.content = b""
-                self.body_ledger.add(flow.id)
-            if flow.response and flow.id not in self.body_ledger:
-                if len(flow.response.content) > self.content_threshold:
-                    body_buf.append((flow.id, self.type_mappings["body"][2], flow.response.content))
-                    flow.response.content = b""
-            flow_buf.append((flow.id, protobuf.dumps(flow)))
-        with self.con as con:
-            con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?)", flow_buf)
-            con.executemany("INSERT INTO body VALUES(?, ?, ?)", body_buf)
+            self._disassemble(flow)
+            # Serialize copies, so that the in-memory flow keeps its content.
+            f = copy.copy(flow)
+            f.request = copy.deepcopy(flow.request)
+            if flow.response:
+                f.response = copy.deepcopy(flow.response)
+            f.id = flow.id
+            if len(f.request.content) > self.content_threshold and f.id not in self.body_ledger:
+                body_buf.append((f.id, 1, f.request.content))
+                f.request.content = b""
+                self.body_ledger.add(f.id)
+            if f.response and f.id not in self.body_ledger:
+                if len(f.response.content) > self.content_threshold:
+                    body_buf.append((f.id, 2, f.response.content))
+                    f.response.content = b""
+            flow_buf.append((f.id, protobuf.dumps(f)))
+        self.con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?);", flow_buf)
+        if body_buf:
+            self.con.executemany("INSERT INTO body (flow_id, type_id, content) VALUES(?, ?, ?);", body_buf)
+        self.con.commit()
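
The write path above serializes copies, so the caller's flow keeps its body while oversized content is offloaded to the body table. A minimal sketch of that behavior, built on the test helpers and assuming the temporary-database default and a content threshold below 1001 bytes, as the tests below do:

    from mitmproxy.test import tflow
    from mitmproxy.addons import session

    db = session.SessionDB()           # no path -> temporary session database
    f = tflow.tflow()
    f.request.content = b"A" * 1001    # above the content threshold
    db.store_flows([f])
    assert f.request.content == b"A" * 1001   # in-memory flow is untouched
    assert f.id in db and len(db) == 1        # flow row (and body row) written
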
 
     def retrieve_flows(self, ids=None):
         flows = []
         with self.con as con:
             if not ids:
                 sql = "SELECT f.content, b.type_id, b.content " \
-                      "FROM flow f, body b " \
-                      "WHERE f.id = b.flow_id;"
+                      "FROM flow f " \
+                      "LEFT OUTER JOIN body b ON f.id = b.flow_id;"
                 rows = con.execute(sql).fetchall()
             else:
                 sql = "SELECT f.content, b.type_id, b.content " \
-                      "FROM flow f, body b " \
-                      "WHERE f.id = b.flow_id" \
-                      f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])})"
+                      "FROM flow f " \
+                      "LEFT OUTER JOIN body b ON f.id = b.flow_id " \
+                      f"WHERE f.id IN ({','.join(['?' for _ in range(len(ids))])});"
                 rows = con.execute(sql, ids).fetchall()
             for row in rows:
                 flow = protobuf.loads(row[0])
-                typ = self.type_mappings["body"][row[1]]
-                if typ and row[2]:
-                    setattr(getattr(flow, typ), "content", row[2])
+                if row[1]:
+                    typ = self.type_mappings["body"][row[1]]
+                    if typ and row[2]:
+                        setattr(getattr(flow, typ), "content", row[2])
+                flow = self._reassemble(flow)
                 flows.append(flow)
         return flows
 
+    def clear(self):
+        self.con.executescript("DELETE FROM body; DELETE FROM annotation; DELETE FROM flow;")
+
 
 matchall = flowfilter.parse(".")
 
 orders = [
-    ("t", "time"),
-    ("m", "method"),
-    ("u", "url"),
-    ("z", "size")
+    "time",
+    "method",
+    "url",
+    "size"
 ]
 
 
 class Session:
     def __init__(self):
-        self.dbstore = SessionDB(ctx.options.session_path)
-        self._hot_store = []
-        self._view = []
+        self.db_store = None
+        self._hot_store = collections.OrderedDict()
         self._live_components = {}
+        self._view = []
         self.order = orders[0]
         self.filter = matchall
         self._flush_period = 3.0
@@ -191,6 +239,7 @@ class Session:
     def running(self):
         if not self.started:
             self.started = True
+            self.db_store = SessionDB(ctx.options.session_path)
             loop = asyncio.get_event_loop()
             tasks = (self._writer, self._tweaker)
             loop.create_task(asyncio.gather(*(t() for t in tasks)))
@@ -201,6 +250,60 @@
         if "view_filter" in updated:
             self.set_filter(ctx.options.view_filter)
 
+    async def _writer(self):
+        while True:
+            await asyncio.sleep(self._flush_period)
+            tof = []
+            # _flush_rate drifts as a float under _tweaker, so truncate for range().
+            to_dump = min(int(self._flush_rate), len(self._hot_store))
+            for _ in range(to_dump):
+                tof.append(self._hot_store.popitem(last=False)[1])
+            self.db_store.store_flows(tof)
+
+    async def _tweaker(self):
+        while True:
+            await asyncio.sleep(self._tweak_period)
+            if len(self._hot_store) >= 3 * self._flush_rate:
+                # Falling behind: flush more often and in bigger batches.
+                self._flush_period *= 0.9
+                self._flush_rate *= 1.1
+            elif len(self._hot_store) < self._flush_rate:
+                # Ahead of the arrival rate: back off.
+                self._flush_period *= 1.1
+                self._flush_rate *= 0.9
+
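
The two loops above form a small feedback controller: the writer drains up to `_flush_rate` flows every `_flush_period` seconds, and the tweaker nudges both values toward the arrival rate. A standalone sketch of the drift, using hypothetical starting numbers:

    period, rate = 3.0, 150.0   # hypothetical starting values
    backlog = 600               # flows sitting in the hot store

    for _ in range(5):
        if backlog >= 3 * rate:
            period, rate = period * 0.9, rate * 1.1   # drain faster
        elif backlog < rate:
            period, rate = period * 1.1, rate * 0.9   # back off
        backlog = max(0, backlog - int(rate))         # one writer pass
        print(f"period={period:.2f}s rate={rate:.0f} backlog={backlog}")
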
+    def load_view(self, ids=None):
+        flows = []
+        ids_from_store = []
+        if ids is None:
+            ids = [fid for _, fid in self._view]
+        for fid in ids:
+            # A flow can be both in the database and in hot storage at the
+            # same time; prefer the hot copy, which is the most up to date.
+            if fid in self._hot_store:
+                flows.append(self._hot_store[fid])
+            elif fid in self.db_store:
+                ids_from_store.append(fid)
+            else:
+                flows.append(None)
+        flows += self.db_store.retrieve_flows(ids_from_store)
+        return flows
+
+    def load_storage(self):
+        flows = []
+        flows += self.db_store.retrieve_flows()
+        for flow in self._hot_store.values():
+            flows.append(flow)
+        return flows
+
+    def clear_storage(self):
+        self.db_store.clear()
+        self._hot_store.clear()
+        self._view = []
+
+    def store_count(self):
+        # Flows pending in the hot store may already have been flushed once;
+        # count them only if they are not in the database yet.
+        ln = sum(1 for fid in self._hot_store if fid not in self.db_store)
+        return ln + len(self.db_store)
+
     def _generate_order(self, f: http.HTTPFlow) -> typing.Union[str, int, float]:
         o = self.order
         if o == "time":
@@ -225,19 +328,19 @@
         if order != self.order:
             self.order = order
             newview = [
-                (self._generate_order(f), f.id) for f in self.dbstore.retrieve_flows([t[0] for t in self._view])
+                (self._generate_order(f), f.id) for f in self.load_view()
             ]
             self._view = sorted(newview)
 
     def _refilter(self):
         self._view = []
-        flows = self.dbstore.retrieve_flows()
+        flows = self.load_storage()
         for f in flows:
             if self.filter(f):
                 self._base_add(f)
 
-    def set_filter(self, input_filter: str) -> None:
-        filt = flowfilter.parse(input_filter)
+    def set_filter(self, input_filter: typing.Optional[str]) -> None:
+        filt = matchall if not input_filter else flowfilter.parse(input_filter)
         if not filt:
             raise CommandError(
                 "Invalid interception filter: %s" % filt
             )
@@ -245,54 +348,21 @@
         self.filter = filt
         self._refilter()
 
-    async def _writer(self):
-        while True:
-            await asyncio.sleep(self._flush_period)
-            tof = []
-            to_dump = min(self._flush_rate, len(self._hot_store))
-            for _ in range(to_dump):
-                tof.append(self._hot_store.pop())
-            self.store(tof)
-
-    async def _tweaker(self):
-        while True:
-            await asyncio.sleep(self._tweak_period)
-            if len(self._hot_store) >= self._flush_rate:
-                self._flush_period *= 0.9
-                self._flush_rate *= 0.9
-            elif len(self._hot_store) < self._flush_rate:
-                self._flush_period *= 1.1
-                self._flush_rate *= 1.1
-
-    def store(self, flows: typing.Sequence[http.HTTPFlow]) -> None:
-        # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality.
-        # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually
-        # adding them back when needed.
-        for f in flows:
-            self._live_components[f.id] = (
-                f.client_conn.wfile or None,
-                f.client_conn.rfile or None,
-                f.server_conn.wfile or None,
-                f.server_conn.rfile or None,
-                f.reply or None
-            )
-        self.dbstore.store_flows(flows)
-
     def _base_add(self, f):
-        if f.id not in self._view:
+        if not any(f.id == t[1] for t in self._view):
             o = self._generate_order(f)
             self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id))
         else:
             o = self._generate_order(f)
-            self._view = [flow for flow in self._view if flow.id != f.id]
+            self._view = [(order, fid) for order, fid in self._view if fid != f.id]
             self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id))
 
     def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None:
         for f in flows:
+            if f.id in self._hot_store:
+                self._hot_store.pop(f.id)
+            self._hot_store[f.id] = f
             if self.filter(f):
-                if f.id in [f.id for f in self._hot_store]:
-                    self._hot_store = [flow for flow in self._hot_store if flow.id != f.id]
-                self._hot_store.append(f)
                 self._base_add(f)
 
     def request(self, f):
diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py
index d4b1109b0..41e8a4014 100644
--- a/test/mitmproxy/addons/test_session.py
+++ b/test/mitmproxy/addons/test_session.py
@@ -1,13 +1,37 @@
 import sqlite3
+import asyncio
 import pytest
 import os
 
+from mitmproxy import ctx
+from mitmproxy import http
+from mitmproxy.test import tflow, tutils
+from mitmproxy.test import taddons
 from mitmproxy.addons import session
-from mitmproxy.exceptions import SessionLoadException
+from mitmproxy.exceptions import SessionLoadException, CommandError
 from mitmproxy.utils.data import pkg_data
 
 
 class TestSession:
+
+    @staticmethod
+    def tft(*, method="GET", start=0):
+        f = tflow.tflow()
+        f.request.method = method
+        f.request.timestamp_start = start
+        return f
+
+    @staticmethod
+    def start_session(fp=None):
+        s = session.Session()
+        tctx = taddons.context()
+        tctx.master.addons.add(s)
+        tctx.options.session_path = None
+        if fp:
+            s._flush_period = fp
+        s.running()
+        return s
+
     def test_session_temporary(self):
         s = session.SessionDB()
         td = s.tempdir
@@ -56,3 +80,130 @@ class TestSession:
         assert len(rows) == 1
         con.close()
         os.remove(path)
+
+    def test_session_order_generators(self):
+        s = session.Session()
+        tf = tflow.tflow(resp=True)
+
+        s.order = "time"
+        assert s._generate_order(tf) == 946681200
+
+        s.order = "method"
+        assert s._generate_order(tf) == tf.request.method
+
+        s.order = "url"
+        assert s._generate_order(tf) == tf.request.url
+
+        s.order = "size"
+        assert s._generate_order(tf) == len(tf.request.raw_content) + len(tf.response.raw_content)
+
+    def test_simple(self):
+        s = self.start_session()
+        f = self.tft(start=1)
+        assert s.store_count() == 0
+        s.request(f)
+        assert s._view == [(1, f.id)]
+        assert s.load_view([f.id]) == [f]
+        assert s.load_view(['nonexistent']) == [None]
+
+        s.error(f)
+        s.response(f)
+        s.intercept(f)
+        s.resume(f)
+        s.kill(f)
+
+        # Verify that the flow has been updated, not duplicated
+        assert s._view == [(1, f.id)]
+        assert s.store_count() == 1
+
+        f2 = self.tft(start=3)
+        s.request(f2)
+        assert s._view == [(1, f.id), (3, f2.id)]
+        s.request(f2)
+        assert s._view == [(1, f.id), (3, f2.id)]
+
+        f3 = self.tft(start=2)
+        s.request(f3)
+        assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)]
+        s.request(f3)
+        assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)]
+        assert s.store_count() == 3
+
+        s.clear_storage()
+        assert len(s._view) == 0
+        assert s.store_count() == 0
+
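
test_simple above exercises the ordered view that `_base_add` maintains as sorted (order, id) tuples. A self-contained sketch of that bisect insertion pattern, with a minimal stand-in for the KeyifyList helper (whose real definition lives elsewhere in mitmproxy):

    import bisect

    class KeyifyList:
        # Stand-in: expose key(x) through __getitem__ so bisect can compare by key.
        def __init__(self, inner, key):
            self.inner, self.key = inner, key
        def __len__(self):
            return len(self.inner)
        def __getitem__(self, i):
            return self.key(self.inner[i])

    view = [(1, "f1"), (3, "f2")]          # sorted (order, flow_id) tuples
    o = 2                                  # order key of the incoming flow
    idx = bisect.bisect_left(KeyifyList(view, lambda x: x[0]), o)
    view.insert(idx, (o, "f3"))
    assert view == [(1, "f1"), (2, "f3"), (3, "f2")]
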
+    def test_filter(self):
+        s = self.start_session()
+        s.request(self.tft(method="GET"))
+        s.request(self.tft(method="PUT"))
+        s.request(self.tft(method="GET"))
+        s.request(self.tft(method="PUT"))
+        assert len(s._view) == 4
+        s.set_filter("~m get")
+        assert [f.request.method for f in s.load_view()] == ["GET", "GET"]
+        assert s.store_count() == 4
+        with pytest.raises(CommandError):
+            s.set_filter("~notafilter")
+        s.set_filter(None)
+        assert len(s._view) == 4
+
+    @pytest.mark.asyncio
+    async def test_flush_withspecials(self):
+        s = self.start_session(fp=0.5)
+        f = self.tft()
+        s.request(f)
+        await asyncio.sleep(2)
+        assert len(s._hot_store) == 0
+        assert all(lflow.__dict__ == flow.__dict__ for lflow, flow in zip(s.load_view(), [f]))
+
+        f.server_conn.via = tflow.tserver_conn()
+        s.request(f)
+        await asyncio.sleep(1)
+        assert len(s._hot_store) == 0
+        assert all(lflow.__dict__ == flow.__dict__ for lflow, flow in zip(s.load_view(), [f]))
+
+        flows = [self.tft() for _ in range(500)]
+        s.update(flows)
+        fp = s._flush_period
+        fr = s._flush_rate
+        await asyncio.sleep(0.6)
+        assert s._flush_period < fp and s._flush_rate > fr
+
+    @pytest.mark.asyncio
+    async def test_bodies(self):
+        # TODO: also cover configure() and set_order().
+        s = self.start_session(fp=0.5)
+        f = self.tft()
+        f2 = self.tft(start=1)
+        f.request.content = b"A" * 1001
+        s.request(f)
+        s.request(f2)
+        await asyncio.sleep(1.0)
+        content = s.db_store.con.execute(
+            "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id]
+        ).fetchall()[0]
+        assert content == (1, b"A" * 1001)
+        assert s.db_store.body_ledger == {f.id}
+        f.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001))
+        f2.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001))
+        # The canned tresp() headers carry a content-length that does not match
+        # our custom body; fix it up explicitly.
+        f.response.headers['content-length'] = b"1001"
+        f2.response.headers['content-length'] = b"1001"
+        s.response(f)
+        s.response(f2)
+        await asyncio.sleep(1.0)
+        # f's id is already in the body ledger, so its response body must not be
+        # offloaded again, while f2's response body is stored.
+        rows = s.db_store.con.execute(
+            "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id]
+        ).fetchall()
+        assert len(rows) == 1
+        rows = s.db_store.con.execute(
+            "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f2.id]
+        ).fetchall()
+        assert len(rows) == 1
+        assert s.db_store.body_ledger == {f.id}
+        assert all(lf.__dict__ == rf.__dict__ for lf, rf in zip(s.load_view(), [f, f2]))
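
For reviewers who want to poke at a saved session file by hand, a sketch based on the flow/body schema the queries above assume (flow(id, content) holding serialized blobs, body(flow_id, type_id, content) holding offloaded bodies). The file path is hypothetical, and the protobuf import path is assumed to be the one the addon itself uses:

    import sqlite3
    from mitmproxy.io import protobuf

    con = sqlite3.connect("session.db")   # hypothetical saved session path
    for (blob,) in con.execute("SELECT content FROM flow;"):
        flow = protobuf.loads(blob)        # same one-arg call retrieve_flows makes
        print(flow.id, flow.request.method, flow.request.url)
    con.close()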