diff --git a/libmproxy/console.py b/libmproxy/console.py index fe0326e1a..fbd3617a8 100644 --- a/libmproxy/console.py +++ b/libmproxy/console.py @@ -160,7 +160,7 @@ class ConnectionItem(WWrap): self.master.statusbar.message("Can't delete connection mid-intercept.") self.master.sync_list_view() elif key == "r": - r = self.state.replay(self.flow, self.master.masterq) + r = self.state.replay_request(self.flow, self.master.masterq) if r: self.master.statusbar.message(r) self.master.sync_list_view() @@ -511,7 +511,7 @@ class ConnectionView(WWrap): elif key == "p": self.master.view_prev_flow(self.flow) elif key == "r": - r = self.state.replay(self.flow, self.master.masterq) + r = self.state.replay_request(self.flow, self.master.masterq) if r: self.master.statusbar.message(r) self.master.refresh_connection(self.flow) diff --git a/libmproxy/flow.py b/libmproxy/flow.py index d081e66c0..16a2714c6 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -2,7 +2,7 @@ This module provides more sophisticated flow tracking. These match requests with their responses, and provide filtering and interception facilities. """ -import subprocess, base64, sys, json +import subprocess, base64, sys, json, hashlib import proxy, threading, netstring import controller @@ -14,7 +14,7 @@ class RunException(Exception): # begin nocover -class ReplayThread(threading.Thread): +class RequestReplayThread(threading.Thread): def __init__(self, flow, masterq): self.flow, self.masterq = flow, masterq threading.Thread.__init__(self) @@ -31,6 +31,49 @@ class ReplayThread(threading.Thread): # end nocover +class ServerPlaybackState: + def __init__(self): + self.fmap = {} + + def __len__(self): + return sum([len(i) for i in self.fmap.values()]) + + def load(self, flows): + """ + Load a sequence of flows. We assume that the sequence is in + chronological order. + """ + for i in flows: + h = self._hash(i) + l = self.fmap.setdefault(self._hash(i), []) + l.append(i) + + def _hash(self, flow): + """ + Calculates a loose hash of the flow request. + """ + r = flow.request + key = [ + str(r.host), + str(r.port), + str(r.scheme), + str(r.method), + str(r.path), + str(r.content), + ] + return hashlib.sha256(repr(key)).digest() + + def next_flow(self, request): + """ + Returns the next flow object, or None if no matching flow was + found. + """ + l = self.fmap.get(self._hash(request)) + if not l: + return None + return l.pop(0) + + class Flow: def __init__(self, request): self.request = request @@ -262,7 +305,7 @@ class State: def revert(self, f): f.revert() - def replay(self, f, masterq): + def replay_request(self, f, masterq): """ Returns None if successful, or error message if not. """ @@ -276,7 +319,7 @@ class State: f.request.headers["content-length"] = [str(len(f.request.content))] f.response = None f.error = None - rt = ReplayThread(f, masterq) + rt = RequestReplayThread(f, masterq) rt.start() #end nocover diff --git a/test/test_flow.py b/test/test_flow.py index 7714ad977..adfeda6ea 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -3,6 +3,43 @@ from libmproxy import console, proxy, filt, flow import utils import libpry + +class uServerPlaybackState(libpry.AutoTree): + def test_hash(self): + s = flow.ServerPlaybackState() + r = utils.tflow() + r2 = utils.tflow() + + assert s._hash(r) + assert s._hash(r) == s._hash(r2) + r.request.headers["foo"] = ["bar"] + assert s._hash(r) == s._hash(r2) + r.request.path = "voing" + assert s._hash(r) != s._hash(r2) + + def test_load(self): + s = flow.ServerPlaybackState() + r = utils.tflow() + r.request.headers["key"] = ["one"] + + r2 = utils.tflow() + r2.request.headers["key"] = ["two"] + + s.load([r, r2]) + assert len(s) == 2 + assert len(s.fmap.keys()) == 1 + + n = s.next_flow(r) + assert n.request.headers["key"] == ["one"] + assert len(s) == 1 + + n = s.next_flow(r) + assert n.request.headers["key"] == ["two"] + assert len(s) == 0 + + assert not s.next_flow(r) + + class uFlow(libpry.AutoTree): def test_run_script(self): f = utils.tflow() @@ -275,6 +312,7 @@ class uFlowMaster(libpry.AutoTree): tests = [ + uServerPlaybackState(), uFlow(), uState(), uSerialize(), diff --git a/test/test_netstring.py b/test/test_netstring.py index 5146d1507..482859600 100644 --- a/test/test_netstring.py +++ b/test/test_netstring.py @@ -3,7 +3,6 @@ from cStringIO import StringIO import libpry - class uNetstring(libpry.AutoTree): def setUp(self): self.test_data = "Netstring module by Will McGugan"