diff --git a/libmproxy/flow.py b/libmproxy/flow.py index d9df7a1af..03b8b309c 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -168,8 +168,6 @@ class State: self.intercept = None def clientconnect(self, cc): - if not isinstance(cc, proxy.ClientConnect): - assert False self.client_connections.append(cc) def clientdisconnect(self, dc): @@ -182,8 +180,6 @@ class State: """ Add a request to the state. Returns the matching flow. """ - if not isinstance(req, proxy.Request): - assert False f = Flow(req) self.flow_list.insert(0, f) self.flow_map[req] = f @@ -193,8 +189,6 @@ class State: """ Add a response to the state. Returns the matching flow. """ - if not isinstance(resp, proxy.Response): - assert False f = self.flow_map.get(resp.request) if not f: return False @@ -230,14 +224,6 @@ class State: else: return tuple(self.flow_list[:]) - def get_client_conn(self, itm): - if isinstance(itm, proxy.ClientConnect): - return itm - elif hasattr(itm, "client_conn"): - return itm.client_conn - elif hasattr(itm, "request"): - return itm.request.client_conn - def delete_flow(self, f): if not f.intercepting: if f.request in self.flow_map: @@ -259,7 +245,6 @@ class State: self.delete_flow(f) def revert(self, f): - conn = self.get_client_conn(f) f.revert() def replay(self, f, masterq): @@ -271,7 +256,6 @@ class State: return "Can't replay while intercepting..." if f.request: f.backup() - conn = self.get_client_conn(f) f.request.set_replay() if f.request.content: f.request.headers["content-length"] = [str(len(f.request.content))] @@ -297,15 +281,11 @@ class FlowMaster(controller.Master): def handle_error(self, r): f = self.state.add_error(r) - if not f: - r.ack() + r.ack() return f def handle_request(self, r): - f = self.state.add_request(r) - if not f: - r.ack() - return f + return self.state.add_request(r) def handle_response(self, r): f = self.state.add_response(r) diff --git a/test/test_flow.py b/test/test_flow.py index 3998943c3..eb6a7c8c0 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -26,6 +26,7 @@ class uFlow(libpry.AutoTree): f.response = utils.tresp() f.request = f.response.request assert not f.match(filt.parse("~b test")) + assert not f.match(None) def test_backup(self): f = utils.tflow() @@ -41,8 +42,12 @@ class uFlow(libpry.AutoTree): def test_getset_state(self): f = utils.tflow() - f.response = utils.tresp() - f.request = f.response.request + f.response = utils.tresp(f.request) + state = f.get_state() + assert f == flow.Flow.from_state(state) + + f.response = None + f.error = proxy.Error(f, "error") state = f.get_state() assert f == flow.Flow.from_state(state) @@ -236,18 +241,29 @@ class uSerialize(libpry.AutoTree): class uFlowMaster(libpry.AutoTree): - def test_one(self): + def test_all(self): s = flow.State() - f = flow.FlowMaster(None, s) + fm = flow.FlowMaster(None, s) req = utils.treq() - f.handle_request(req) + fm.handle_clientconnect(req.client_conn) + + f = fm.handle_request(req) assert len(s.flow_list) == 1 resp = utils.tresp(req) - f.handle_response(resp) + fm.handle_response(resp) assert len(s.flow_list) == 1 + + rx = utils.tresp() + assert not fm.handle_response(rx) + dc = proxy.ClientDisconnect(req.client_conn) + fm.handle_clientdisconnect(dc) + + err = proxy.Error(f, "msg") + fm.handle_error(err) + tests = [