diff --git a/libmproxy/flow.py b/libmproxy/flow.py index a82b978c1..a4e0a79c8 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -26,34 +26,58 @@ class RequestReplayThread(threading.Thread): response = server.read_response() response.send(self.masterq) except proxy.ProxyError, v: - err = proxy.Error(self.flow.client_conn, v.msg) + err = proxy.Error(self.flow.request, v.msg) err.send(self.masterq) # end nocover +class ClientPlaybackState: + def __init__(self, flows): + self.flows = flows + self.current = None + + def count(self): + return len(self.flows) + + def clear(self, flow): + """ + A request has returned in some way - if this is the one we're + servicing, go to the next flow. + """ + if flow is self.current: + self.current = None + + def tick(self, master, testing=False): + """ + testing: Disables actual replay for testing. + """ + if self.flows and not self.current: + self.current = self.flows.pop(0) + self.current.response = None + master.handle_request(self.current.request) + if not testing: + #begin nocover + master.state.replay_request(self.current, master.masterq) + #end nocover + + class ServerPlaybackState: - def __init__(self, headers): + def __init__(self, headers, flows): """ headers: A case-insensitive list of request headers that should be included in request-response matching. """ self.headers = headers self.fmap = {} - - def count(self): - return sum([len(i) for i in self.fmap.values()]) - - def load(self, flows): - """ - Load a sequence of flows. We assume that the sequence is in - chronological order. - """ for i in flows: if i.response: h = self._hash(i) l = self.fmap.setdefault(self._hash(i), []) l.append(i) + def count(self): + return sum([len(i) for i in self.fmap.values()]) + def _hash(self, flow): """ Calculates a loose hash of the flow request. @@ -415,8 +439,7 @@ class FlowMaster(controller.Master): flows: A list of flows. kill: Boolean, should we kill requests not part of the replay? """ - self.playback = ServerPlaybackState(headers) - self.playback.load(flows) + self.playback = ServerPlaybackState(headers, flows) self.kill_nonreplay = kill def do_playback(self, flow): diff --git a/test/test_flow.py b/test/test_flow.py index 10b445dbe..f9fb58e83 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -33,10 +33,32 @@ class uStickyCookieState(libpry.AutoTree): assert "cookie" in f.request.headers +class uClientPlaybackState(libpry.AutoTree): + def test_tick(self): + first = utils.tflow() + c = flow.ClientPlaybackState( + [first, utils.tflow()] + ) + s = flow.State() + fm = flow.FlowMaster(None, s) + assert not s.flow_map + assert c.count() == 2 + c.tick(fm, testing=True) + assert s.flow_map + assert c.count() == 1 + + c.tick(fm, testing=True) + assert c.count() == 1 + + c.clear(first) + c.tick(fm, testing=True) + assert c.count() == 0 + + class uServerPlaybackState(libpry.AutoTree): def test_hash(self): - s = flow.ServerPlaybackState(None) + s = flow.ServerPlaybackState(None, []) r = utils.tflow() r2 = utils.tflow() @@ -48,7 +70,7 @@ class uServerPlaybackState(libpry.AutoTree): assert s._hash(r) != s._hash(r2) def test_headers(self): - s = flow.ServerPlaybackState(["foo"]) + s = flow.ServerPlaybackState(["foo"], []) r = utils.tflow_full() r.request.headers["foo"] = ["bar"] r2 = utils.tflow_full() @@ -63,14 +85,13 @@ class uServerPlaybackState(libpry.AutoTree): assert s._hash(r) == s._hash(r2) def test_load(self): - s = flow.ServerPlaybackState(None) r = utils.tflow_full() r.request.headers["key"] = ["one"] r2 = utils.tflow_full() r2.request.headers["key"] = ["two"] - s.load([r, r2]) + s = flow.ServerPlaybackState(None, [r, r2]) assert s.count() == 2 assert len(s.fmap.keys()) == 1 @@ -398,6 +419,7 @@ class uFlowMaster(libpry.AutoTree): tests = [ uStickyCookieState(), uServerPlaybackState(), + uClientPlaybackState(), uFlow(), uState(), uSerialize(),