diff --git a/libmproxy/console.py b/libmproxy/console.py index 6fd8c2caf..6a0c91254 100644 --- a/libmproxy/console.py +++ b/libmproxy/console.py @@ -208,7 +208,7 @@ class ConnectionItem(WWrap): class ConnectionListView(urwid.ListWalker): def __init__(self, master, state): self.master, self.state = master, state - if self.state.flow_list: + if self.state.flow_count(): self.set_focus(0) def get_focus(self): @@ -709,7 +709,7 @@ class StatusBar(WWrap): self.message("") t = [ - ('statusbar_text', ("[%s]"%len(self.master.state.flow_list)).ljust(7)), + ('statusbar_text', ("[%s]"%self.master.state.flow_count()).ljust(7)), ] t.extend(self.get_status()) @@ -1657,8 +1657,7 @@ class ConsoleMaster(flow.FlowMaster): self.refresh_server_playback = not self.refresh_server_playback def shutdown(self): - for i in self.state.flow_list: - i.kill(self) + self.state.killall(self) controller.Master.shutdown(self) def sync_list_view(self): @@ -1681,7 +1680,7 @@ class ConsoleMaster(flow.FlowMaster): self.statusbar.refresh_connection(c) def process_flow(self, f, r): - if f.match(self.state.intercept) and not f.request.is_replay(): + if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay(): f.intercept() else: r.ack() diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 8267ff43f..c15f45665 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -301,7 +301,8 @@ class Flow: return pattern(self.response) elif self.request: return pattern(self.request) - return False + else: + return True def kill(self, master): self.error = proxy.Error(self.request, "Connection killed") @@ -339,20 +340,26 @@ class Flow: class State: def __init__(self): self.client_connections = [] - self.flow_map = {} - self.flow_list = [] + + self._flow_map = {} + self._flow_list = [] + self.view = [] # These are compiled filt expressions: - self.limit = None + self._limit = None self.intercept = None - self.limit_txt = None + self._limit_txt = None + + @property + def limit_txt(self): + return self._limit_txt def flow_count(self): - return len(self.flow_map) + return len(self._flow_map) def active_flow_count(self): c = 0 - for i in self.flow_list: + for i in self._flow_list: if not i.response and not i.error: c += 1 return c @@ -371,18 +378,22 @@ class State: Add a request to the state. Returns the matching flow. """ f = Flow(req) - self.flow_list.append(f) - self.flow_map[req] = f + self._flow_list.append(f) + self._flow_map[req] = f + if f.match(self._limit): + self.view.append(f) return f def add_response(self, resp): """ Add a response to the state. Returns the matching flow. """ - f = self.flow_map.get(resp.request) + f = self._flow_map.get(resp.request) if not f: return False f.response = resp + if f.match(self._limit) and not f in self.view: + self.view.append(f) return f def add_error(self, err): @@ -390,27 +401,31 @@ class State: Add an error response to the state. Returns the matching flow, or None if there isn't one. """ - f = self.flow_map.get(err.request) if err.request else None + f = self._flow_map.get(err.request) if err.request else None if not f: return None f.error = err + if f.match(self._limit) and not f in self.view: + self.view.append(f) return f def load_flows(self, flows): - self.flow_list.extend(flows) + self._flow_list.extend(flows) for i in flows: - self.flow_map[i.request] = i + self._flow_map[i.request] = i + self.recalculate_view() def set_limit(self, txt): if txt: f = filt.parse(txt) if not f: return "Invalid filter expression." - self.limit = f - self.limit_txt = txt + self._limit = f + self._limit_txt = txt else: - self.limit = None - self.limit_txt = None + self._limit = None + self._limit_txt = None + self.recalculate_view() def set_intercept(self, txt): if txt: @@ -423,30 +438,35 @@ class State: self.intercept = None self.intercept_txt = None - @property - def view(self): - if self.limit: - return tuple([i for i in self.flow_list if i.match(self.limit)]) + def recalculate_view(self): + if self._limit: + self.view = [i for i in self._flow_list if i.match(self._limit)] else: - return tuple(self.flow_list[:]) + self.view = self._flow_list[:] def delete_flow(self, f): - if f.request in self.flow_map: - del self.flow_map[f.request] - self.flow_list.remove(f) + if f.request in self._flow_map: + del self._flow_map[f.request] + self._flow_list.remove(f) + if f.match(self._limit): + self.view.remove(f) return True def clear(self): - for i in self.flow_list[:]: + for i in self._flow_list[:]: self.delete_flow(i) def accept_all(self): - for i in self.flow_list[:]: + for i in self._flow_list[:]: i.accept_intercept() def revert(self, f): f.revert() + def killall(self, master): + for i in self._flow_list: + i.kill(master) + class FlowMaster(controller.Master): diff --git a/test/test_console.py b/test/test_console.py index 6040763a3..f0c939ed6 100644 --- a/test/test_console.py +++ b/test/test_console.py @@ -12,7 +12,7 @@ class uState(libpry.AutoTree): """ c = console.ConsoleState() f = self._add_request(c) - assert f.request in c.flow_map + assert f.request in c._flow_map assert c.get_focus() == (f, 0) def test_focus(self): diff --git a/test/test_flow.py b/test/test_flow.py index ba276248f..ec62f4c8b 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -152,7 +152,7 @@ class uFlow(libpry.AutoTree): f.response = tutils.tresp() f.request = f.response.request assert not f.match(filt.parse("~b test")) - assert not f.match(None) + assert f.match(None) def test_backup(self): f = tutils.tflow() @@ -260,23 +260,23 @@ class uState(libpry.AutoTree): f = c.add_request(req) assert f assert c.flow_count() == 1 - assert c.flow_map.get(req) + assert c._flow_map.get(req) assert c.active_flow_count() == 1 newreq = tutils.treq() assert c.add_request(newreq) - assert c.flow_map.get(newreq) + assert c._flow_map.get(newreq) assert c.active_flow_count() == 2 resp = tutils.tresp(req) assert c.add_response(resp) assert c.flow_count() == 2 - assert c.flow_map.get(resp.request) + assert c._flow_map.get(resp.request) assert c.active_flow_count() == 1 unseen_resp = tutils.tresp() assert not c.add_response(unseen_resp) - assert not c.flow_map.get(unseen_resp.request) + assert not c._flow_map.get(unseen_resp.request) assert c.active_flow_count() == 1 resp = tutils.tresp(newreq) @@ -373,7 +373,7 @@ class uState(libpry.AutoTree): c.clear() c.load_flows(flows) - assert isinstance(c.flow_list[0], flow.Flow) + assert isinstance(c._flow_list[0], flow.Flow) def test_accept_all(self): c = flow.State() @@ -418,7 +418,7 @@ class uSerialize(libpry.AutoTree): s = flow.State() fm = flow.FlowMaster(None, s) fm.load_flows(r) - assert len(s.flow_list) == 6 + assert len(s._flow_list) == 6 def test_error(self):