tests: 97% coverage reached. Session patched to address defects that emerged during testing.

This commit is contained in:
madt1m 2018-08-02 05:55:35 +02:00
parent a839d2ee2a
commit 4e0c10b88b
2 changed files with 296 additions and 75 deletions

View File

@ -1,9 +1,11 @@
import collections
import tempfile
import asyncio
import typing
import bisect
import shutil
import sqlite3
import copy
import os
from mitmproxy import flowfilter
@ -48,6 +50,7 @@ class SessionDB:
or create a new one with optional path.
:param db_path:
"""
self.live_components = {}
self.tempdir = None
self.con = None
# This is used for fast look-ups over bodies already dumped to database.
@ -71,11 +74,14 @@ class SessionDB:
shutil.rmtree(self.tempdir)
def __contains__(self, fid):
return fid in self._get_ids()
return any([fid == i for i in self._get_ids()])
def __len__(self):
ln = self.con.execute("SELECT COUNT(*) FROM flow;").fetchall()[0]
return ln[0] if ln else 0
def _get_ids(self):
with self.con as con:
return [t[0] for t in con.execute("SELECT id FROM flow;").fetchall()]
return [t[0] for t in self.con.execute("SELECT id FROM flow;").fetchall()]
def _load_session(self, path):
if not self.is_session_db(path):
@ -85,8 +91,8 @@ class SessionDB:
def _create_session(self):
script_path = pkg_data.path("io/sql/session_create.sql")
qry = open(script_path, 'r').read()
with self.con:
self.con.executescript(qry)
self.con.executescript(qry)
self.con.commit()
@staticmethod
def is_session_db(path):
@ -110,62 +116,104 @@ class SessionDB:
c.close()
return False
def _disassemble(self, flow):
# Some live components of flows cannot be serialized, but they are needed to ensure correct functionality.
# We solve this by keeping a list of tuples which "save" those components for each flow id, eventually
# adding them back when needed.
self.live_components[flow.id] = (
flow.client_conn.wfile,
flow.client_conn.rfile,
flow.client_conn.reply,
flow.server_conn.wfile,
flow.server_conn.rfile,
flow.server_conn.reply,
(flow.server_conn.via.wfile, flow.server_conn.via.rfile,
flow.server_conn.via.reply) if flow.server_conn.via else None,
flow.reply
)
def _reassemble(self, flow):
if flow.id in self.live_components:
cwf, crf, crp, swf, srf, srp, via, rep = self.live_components[flow.id]
flow.client_conn.wfile = cwf
flow.client_conn.rfile = crf
flow.client_conn.reply = crp
flow.server_conn.wfile = swf
flow.server_conn.rfile = srf
flow.server_conn.reply = srp
flow.reply = rep
if via:
flow.server_conn.via.rfile, flow.server_conn.via.wfile, flow.server_conn.via.reply = via
return flow
def store_flows(self, flows):
body_buf = []
flow_buf = []
for flow in flows:
if len(flow.request.content) > self.content_threshold:
body_buf.append((flow.id, self.type_mappings["body"][1], flow.request.content))
flow.request.content = b""
self.body_ledger.add(flow.id)
if flow.response and flow.id not in self.body_ledger:
if len(flow.response.content) > self.content_threshold:
body_buf.append((flow.id, self.type_mappings["body"][2], flow.response.content))
flow.response.content = b""
flow_buf.append((flow.id, protobuf.dumps(flow)))
with self.con as con:
con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?)", flow_buf)
con.executemany("INSERT INTO body VALUES(?, ?, ?)", body_buf)
self._disassemble(flow)
f = copy.copy(flow)
f.request = copy.deepcopy(flow.request)
if flow.response:
f.response = copy.deepcopy(flow.response)
f.id = flow.id
if len(f.request.content) > self.content_threshold and f.id not in self.body_ledger:
body_buf.append((f.id, 1, f.request.content))
f.request.content = b""
self.body_ledger.add(f.id)
if f.response and f.id not in self.body_ledger:
if len(f.response.content) > self.content_threshold:
body_buf.append((f.id, 2, f.response.content))
f.response.content = b""
flow_buf.append((f.id, protobuf.dumps(f)))
self.con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?);", flow_buf)
if body_buf:
self.con.executemany("INSERT INTO body (flow_id, type_id, content) VALUES(?, ?, ?);", body_buf)
self.con.commit()
def retrieve_flows(self, ids=None):
flows = []
with self.con as con:
if not ids:
sql = "SELECT f.content, b.type_id, b.content " \
"FROM flow f, body b " \
"WHERE f.id = b.flow_id;"
"FROM flow f " \
"LEFT OUTER JOIN body b ON f.id = b.flow_id;"
rows = con.execute(sql).fetchall()
else:
sql = "SELECT f.content, b.type_id, b.content " \
"FROM flow f, body b " \
"WHERE f.id = b.flow_id" \
f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])})"
"FROM flow f " \
"LEFT OUTER JOIN body b ON f.id = b.flow_id " \
f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])});"
rows = con.execute(sql, ids).fetchall()
for row in rows:
flow = protobuf.loads(row[0])
typ = self.type_mappings["body"][row[1]]
if typ and row[2]:
setattr(getattr(flow, typ), "content", row[2])
if row[1]:
typ = self.type_mappings["body"][row[1]]
if typ and row[2]:
setattr(getattr(flow, typ), "content", row[2])
flow = self._reassemble(flow)
flows.append(flow)
return flows
    def clear(self):
        """Delete every stored flow, body and annotation from the database."""
        self.con.executescript("DELETE FROM body; DELETE FROM annotation; DELETE FROM flow;")
# Filter that matches every flow; used as the default view filter.
matchall = flowfilter.parse(".")

# Sort keys supported by the session view. Diff residue had the removed
# (shortcut, name) tuple form interleaved; keep the new plain-string form.
orders = [
    "time",
    "method",
    "url",
    "size"
]
class Session:
def __init__(self):
self.dbstore = SessionDB(ctx.options.session_path)
self._hot_store = []
self._view = []
self.db_store = None
self._hot_store = collections.OrderedDict()
self._live_components = {}
self._view = []
self.order = orders[0]
self.filter = matchall
self._flush_period = 3.0
@ -191,6 +239,7 @@ class Session:
    def running(self):
        # Addon hook: fired once the proxy is up. The on-disk store is
        # created here rather than in __init__ because ctx.options is only
        # populated by now; the background writer/tweaker coroutines are
        # scheduled on the running event loop.
        if not self.started:
            self.started = True
            self.db_store = SessionDB(ctx.options.session_path)
            loop = asyncio.get_event_loop()
            tasks = (self._writer, self._tweaker)
            loop.create_task(asyncio.gather(*(t() for t in tasks)))
@ -201,6 +250,60 @@ class Session:
if "view_filter" in updated:
self.set_filter(ctx.options.view_filter)
async def _writer(self):
while True:
await asyncio.sleep(self._flush_period)
tof = []
to_dump = min(self._flush_rate, len(self._hot_store))
for _ in range(to_dump):
tof.append(self._hot_store.popitem(last=False)[1])
self.db_store.store_flows(tof)
    async def _tweaker(self):
        # Adaptive back-pressure: if the hot store holds several flush
        # batches worth of flows, flush sooner and in bigger batches; if it
        # holds less than one batch, relax both knobs.
        # NOTE(review): _flush_rate drifts to a float after the first
        # adjustment -- confirm every consumer copes with a non-int rate.
        while True:
            await asyncio.sleep(self._tweak_period)
            if len(self._hot_store) >= 3 * self._flush_rate:
                self._flush_period *= 0.9
                self._flush_rate *= 1.1
            elif len(self._hot_store) < self._flush_rate:
                self._flush_period *= 1.1
                self._flush_rate *= 0.9
def load_view(self, ids=None):
flows = []
ids_from_store = []
if ids is None:
ids = [fid for _, fid in self._view]
for fid in ids:
# Flow could be at the same time in database and in hot storage. We want the most updated version.
if fid in self._hot_store:
flows.append(self._hot_store[fid])
elif fid in self.db_store:
ids_from_store.append(fid)
else:
flows.append(None)
flows += self.db_store.retrieve_flows(ids_from_store)
return flows
def load_storage(self):
flows = []
flows += self.db_store.retrieve_flows()
for flow in self._hot_store.values():
flows.append(flow)
return flows
    def clear_storage(self):
        # Wipe both persistence layers and the ordered view.
        self.db_store.clear()
        self._hot_store.clear()
        self._view = []
def store_count(self):
ln = 0
for fid in self._hot_store.keys():
if fid not in self.db_store:
ln += 1
return ln + len(self.db_store)
def _generate_order(self, f: http.HTTPFlow) -> typing.Union[str, int, float]:
o = self.order
if o == "time":
@ -225,19 +328,19 @@ class Session:
if order != self.order:
self.order = order
newview = [
(self._generate_order(f), f.id) for f in self.dbstore.retrieve_flows([t[0] for t in self._view])
(self._generate_order(f), f.id) for f in self.load_view()
]
self._view = sorted(newview)
def _refilter(self):
self._view = []
flows = self.dbstore.retrieve_flows()
flows = self.load_storage()
for f in flows:
if self.filter(f):
self._base_add(f)
def set_filter(self, input_filter: str) -> None:
filt = flowfilter.parse(input_filter)
def set_filter(self, input_filter: typing.Optional[str]) -> None:
filt = matchall if not input_filter else flowfilter.parse(input_filter)
if not filt:
raise CommandError(
"Invalid interception filter: %s" % filt
@ -245,54 +348,21 @@ class Session:
self.filter = filt
self._refilter()
async def _writer(self):
while True:
await asyncio.sleep(self._flush_period)
tof = []
to_dump = min(self._flush_rate, len(self._hot_store))
for _ in range(to_dump):
tof.append(self._hot_store.pop())
self.store(tof)
async def _tweaker(self):
while True:
await asyncio.sleep(self._tweak_period)
if len(self._hot_store) >= self._flush_rate:
self._flush_period *= 0.9
self._flush_rate *= 0.9
elif len(self._hot_store) < self._flush_rate:
self._flush_period *= 1.1
self._flush_rate *= 1.1
def store(self, flows: typing.Sequence[http.HTTPFlow]) -> None:
# Some live components of flows cannot be serialized, but they are needed to ensure correct functionality.
# We solve this by keeping a list of tuples which "save" those components for each flow id, eventually
# adding them back when needed.
for f in flows:
self._live_components[f.id] = (
f.client_conn.wfile or None,
f.client_conn.rfile or None,
f.server_conn.wfile or None,
f.server_conn.rfile or None,
f.reply or None
)
self.dbstore.store_flows(flows)
def _base_add(self, f):
if f.id not in self._view:
if not any([f.id == t[1] for t in self._view]):
o = self._generate_order(f)
self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id))
else:
o = self._generate_order(f)
self._view = [flow for flow in self._view if flow.id != f.id]
self._view = [(order, fid) for order, fid in self._view if fid != f.id]
self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id))
def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None:
for f in flows:
if f.id in self._hot_store:
self._hot_store.pop(f.id)
self._hot_store[f.id] = f
if self.filter(f):
if f.id in [f.id for f in self._hot_store]:
self._hot_store = [flow for flow in self._hot_store if flow.id != f.id]
self._hot_store.append(f)
self._base_add(f)
def request(self, f):

View File

@ -1,13 +1,37 @@
import sqlite3
import asyncio
import pytest
import os
from mitmproxy import ctx
from mitmproxy import http
from mitmproxy.test import tflow, tutils
from mitmproxy.test import taddons
from mitmproxy.addons import session
from mitmproxy.exceptions import SessionLoadException
from mitmproxy.exceptions import SessionLoadException, CommandError
from mitmproxy.utils.data import pkg_data
class TestSession:
    @staticmethod
    def tft(*, method="GET", start=0):
        """Build a test flow with the given request method and start time."""
        f = tflow.tflow()
        f.request.method = method
        f.request.timestamp_start = start
        return f
    @staticmethod
    def start_session(fp=None):
        """Create a Session addon inside a test addon context (with an
        in-temporary-dir session path), optionally override its flush
        period, fire its running() hook and return it."""
        s = session.Session()
        tctx = taddons.context()
        tctx.master.addons.add(s)
        tctx.options.session_path = None
        if fp:
            s._flush_period = fp
        s.running()
        return s
def test_session_temporary(self):
s = session.SessionDB()
td = s.tempdir
@ -56,3 +80,130 @@ class TestSession:
assert len(rows) == 1
con.close()
os.remove(path)
def test_session_order_generators(self):
s = session.Session()
tf = tflow.tflow(resp=True)
s.order = "time"
assert s._generate_order(tf) == 946681200
s.order = "method"
assert s._generate_order(tf) == tf.request.method
s.order = "url"
assert s._generate_order(tf) == tf.request.url
s.order = "size"
assert s._generate_order(tf) == len(tf.request.raw_content) + len(tf.response.raw_content)
    def test_simple(self):
        # End-to-end smoke test: view ordering, lookups, event hooks and
        # store counting across a full flow lifecycle.
        s = session.Session()
        ctx.options = taddons.context()
        ctx.options.session_path = None
        s.running()
        f = self.tft(start=1)
        assert s.store_count() == 0
        s.request(f)
        assert s._view == [(1, f.id)]
        assert s.load_view([f.id]) == [f]
        # Unknown ids resolve to None placeholders.
        assert s.load_view(['nonexistent']) == [None]
        s.error(f)
        s.response(f)
        s.intercept(f)
        s.resume(f)
        s.kill(f)
        # Verify that flow has been updated, not duplicated
        assert s._view == [(1, f.id)]
        assert s.store_count() == 1
        f2 = self.tft(start=3)
        s.request(f2)
        assert s._view == [(1, f.id), (3, f2.id)]
        s.request(f2)
        assert s._view == [(1, f.id), (3, f2.id)]
        # Flows are inserted at their sorted position, not appended.
        f3 = self.tft(start=2)
        s.request(f3)
        assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)]
        s.request(f3)
        assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)]
        assert s.store_count() == 3
        s.clear_storage()
        assert len(s._view) == 0
        assert s.store_count() == 0
    def test_filter(self):
        # Setting a filter narrows the view but never the backing store.
        s = self.start_session()
        s.request(self.tft(method="get"))
        s.request(self.tft(method="put"))
        s.request(self.tft(method="get"))
        s.request(self.tft(method="put"))
        assert len(s._view) == 4
        s.set_filter("~m get")
        assert [f.request.method for f in s.load_view()] == ["GET", "GET"]
        assert s.store_count() == 4
        # Unparsable expressions must raise; a falsy filter resets to
        # match-all and restores the full view.
        with pytest.raises(CommandError):
            s.set_filter("~notafilter")
        s.set_filter(None)
        assert len(s._view) == 4
    @pytest.mark.asyncio
    async def test_flush_withspecials(self):
        # Flows -- including one with a via server connection -- must
        # survive the hot-store -> database round trip, and the tweaker
        # must tighten the flush knobs under load.
        s = self.start_session(fp=0.5)
        f = self.tft()
        s.request(f)
        await asyncio.sleep(2)
        assert len(s._hot_store) == 0
        assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_view(), [f]))])
        f.server_conn.via = tflow.tserver_conn()
        s.request(f)
        await asyncio.sleep(1)
        assert len(s._hot_store) == 0
        assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_view(), [f]))])
        # Flood the hot store: the tweaker should shrink the flush period
        # and grow the flush rate.
        flows = [self.tft() for _ in range(500)]
        s.update(flows)
        fp = s._flush_period
        fr = s._flush_rate
        await asyncio.sleep(0.6)
        assert s._flush_period < fp and s._flush_rate > fr
    @pytest.mark.asyncio
    async def test_bodies(self):
        # Need to test for configure
        # Need to test for set_order
        # Bodies above the content threshold must be stripped from the flow
        # and moved to the body table; a flow whose request body was dumped
        # is ledgered so its response body is not externalized as well.
        s = self.start_session(fp=0.5)
        f = self.tft()
        f2 = self.tft(start=1)
        f.request.content = b"A"*1001
        s.request(f)
        s.request(f2)
        await asyncio.sleep(1.0)
        content = s.db_store.con.execute(
            "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id]
        ).fetchall()[0]
        assert content == (1, b"A"*1001)
        assert s.db_store.body_ledger == {f.id}
        f.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A"*1001))
        f2.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A"*1001))
        # Content length is wrong for some reason -- quick fix
        f.response.headers['content-length'] = b"1001"
        f2.response.headers['content-length'] = b"1001"
        s.response(f)
        s.response(f2)
        await asyncio.sleep(1.0)
        rows = s.db_store.con.execute(
            "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id]
        ).fetchall()
        assert len(rows) == 1
        rows = s.db_store.con.execute(
            "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f2.id]
        ).fetchall()
        assert len(rows) == 1
        # f stays ledgered (only its request body was dumped), so only one
        # body row exists per flow id.
        assert s.db_store.body_ledger == {f.id}
        assert all([lf.__dict__ == rf.__dict__ for lf, rf in list(zip(s.load_view(), [f, f2]))])