diff --git a/mitmproxy/console/common.py b/mitmproxy/console/common.py index c7f87a8f7..9fb8b5c92 100644 --- a/mitmproxy/console/common.py +++ b/mitmproxy/console/common.py @@ -410,7 +410,7 @@ def raw_format_flow(f, focus, extended): return urwid.Pile(pile) -def format_flow(f, focus, extended=False, hostheader=False, marked=False): +def format_flow(f, focus, extended=False, hostheader=False): d = dict( intercepted = f.intercepted, acked = f.reply.acked, @@ -423,7 +423,7 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False): err_msg = f.error.msg if f.error else None, - marked = marked, + marked = f.marked, ) if f.response: if f.response.raw_content: diff --git a/mitmproxy/console/flowlist.py b/mitmproxy/console/flowlist.py index aba5759a7..43742083d 100644 --- a/mitmproxy/console/flowlist.py +++ b/mitmproxy/console/flowlist.py @@ -120,23 +120,17 @@ class ConnectionItem(urwid.WidgetWrap): self.flow, self.f, hostheader = self.master.options.showhost, - marked=self.state.flow_marked(self.flow) ) def selectable(self): return True def save_flows_prompt(self, k): - if k == "a": + if k == "l": signals.status_prompt_path.send( - prompt = "Save all flows to", + prompt = "Save listed flows to", callback = self.master.save_flows ) - elif k == "m": - signals.status_prompt_path.send( - prompt = "Save marked flows to", - callback = self.master.save_marked_flows - ) else: signals.status_prompt_path.send( prompt = "Save this flow to", @@ -197,10 +191,7 @@ class ConnectionItem(urwid.WidgetWrap): self.master.state.set_focus_flow(f) signals.flowlist_change.send(self) elif key == "m": - if self.state.flow_marked(self.flow): - self.state.set_flow_marked(self.flow, False) - else: - self.state.set_flow_marked(self.flow, True) + self.flow.marked = not self.flow.marked signals.flowlist_change.send(self) elif key == "M": if self.state.mark_filter: @@ -235,7 +226,7 @@ class ConnectionItem(urwid.WidgetWrap): ) elif key == "U": for f in self.state.flows: - self.state.set_flow_marked(f, False) + f.marked = False signals.flowlist_change.send(self) elif key == "V": if not self.flow.modified(): @@ -249,9 +240,8 @@ class ConnectionItem(urwid.WidgetWrap): self, prompt = "Save", keys = ( - ("all flows", "a"), + ("listed flows", "l"), ("this flow", "t"), - ("marked flows", "m"), ), callback = self.save_flows_prompt, ) diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py index fad4c375a..f7c99ecb9 100644 --- a/mitmproxy/console/master.py +++ b/mitmproxy/console/master.py @@ -34,6 +34,7 @@ from mitmproxy.console import palettes from mitmproxy.console import signals from mitmproxy.console import statusbar from mitmproxy.console import window +from mitmproxy.filt import FMarked from netlib import tcp, strutils EVENTLOG_SIZE = 500 @@ -48,7 +49,7 @@ class ConsoleState(flow.State): self.default_body_view = contentviews.get("Auto") self.flowsettings = weakref.WeakKeyDictionary() self.last_search = None - self.last_filter = None + self.last_filter = "" self.mark_filter = False def __setattr__(self, name, value): @@ -66,7 +67,6 @@ class ConsoleState(flow.State): def add_flow(self, f): super(ConsoleState, self).add_flow(f) self.update_focus() - self.set_flow_marked(f, False) return f def update_flow(self, f): @@ -123,48 +123,71 @@ class ConsoleState(flow.State): self.set_focus(self.focus) return ret - def filter_marked(self, m): - def actual_func(x): - if x.id in m: - return True - return False - return actual_func + def get_nearest_matching_flow(self, flow, filt): + fidx = self.view.index(flow) + dist = 1 + + fprev = fnext = True + while fprev or fnext: + fprev, _ = self.get_from_pos(fidx - dist) + fnext, _ = self.get_from_pos(fidx + dist) + + if fprev and fprev.match(filt): + return fprev + elif fnext and fnext.match(filt): + return fnext + + dist += 1 + + return None def enable_marked_filter(self): + marked_flows = [f for f in self.flows if f.marked] + if not marked_flows: + return + + marked_filter = "~%s" % FMarked.code + + # Save Focus + last_focus, _ = self.get_focus() + nearest_marked = self.get_nearest_matching_flow(last_focus, marked_filter) + self.last_filter = self.limit_txt - marked_flows = [] - for f in self.flows: - if self.flow_marked(f): - marked_flows.append(f.id) - if len(marked_flows) > 0: - f = self.filter_marked(marked_flows) - self.view._close() - self.view = flow.FlowView(self.flows, f) - self.focus = 0 - self.set_focus(self.focus) - self.mark_filter = True + self.set_limit(marked_filter) + + # Restore Focus + if last_focus.marked: + self.set_focus_flow(last_focus) + else: + self.set_focus_flow(nearest_marked) + + self.mark_filter = True def disable_marked_filter(self): - if self.last_filter is None: - self.view = flow.FlowView(self.flows, None) + marked_filter = "~%s" % FMarked.code + + # Save Focus + last_focus, _ = self.get_focus() + nearest_marked = self.get_nearest_matching_flow(last_focus, marked_filter) + + self.set_limit(self.last_filter) + self.last_filter = "" + + # Restore Focus + if last_focus.marked: + self.set_focus_flow(last_focus) else: - self.set_limit(self.last_filter) - self.focus = 0 - self.set_focus(self.focus) - self.last_filter = None + self.set_focus_flow(nearest_marked) + self.mark_filter = False def clear(self): - marked_flows = [] - for f in self.flows: - if self.flow_marked(f): - marked_flows.append(f) - + marked_flows = [f for f in self.state.view if f.marked] super(ConsoleState, self).clear() for f in marked_flows: self.add_flow(f) - self.set_flow_marked(f, True) + f.marked = True if len(self.flows.views) == 0: self.focus = None @@ -172,12 +195,6 @@ class ConsoleState(flow.State): self.focus = 0 self.set_focus(self.focus) - def flow_marked(self, flow): - return self.get_flow_setting(flow, "marked", False) - - def set_flow_marked(self, flow, marked): - self.add_flow_setting(flow, "marked", marked) - class Options(mitmproxy.options.Options): def __init__( @@ -615,13 +632,6 @@ class ConsoleMaster(flow.FlowMaster): def save_flows(self, path): return self._write_flows(path, self.state.view) - def save_marked_flows(self, path): - marked_flows = [] - for f in self.state.view: - if self.state.flow_marked(f): - marked_flows.append(f) - return self._write_flows(path, marked_flows) - def load_flows_callback(self, path): if not path: return diff --git a/mitmproxy/console/statusbar.py b/mitmproxy/console/statusbar.py index 3120fa71a..44be2b3e4 100644 --- a/mitmproxy/console/statusbar.py +++ b/mitmproxy/console/statusbar.py @@ -171,10 +171,6 @@ class StatusBar(urwid.WidgetWrap): r.append("[") r.append(("heading_key", "l")) r.append(":%s]" % self.master.state.limit_txt) - if self.master.state.mark_filter: - r.append("[") - r.append(("heading_key", "Marked Flows")) - r.append("]") if self.master.options.stickycookie: r.append("[") r.append(("heading_key", "t")) diff --git a/mitmproxy/filt.py b/mitmproxy/filt.py index fe8177572..67915e5b6 100644 --- a/mitmproxy/filt.py +++ b/mitmproxy/filt.py @@ -83,6 +83,14 @@ class FErr(_Action): return True if f.error else False +class FMarked(_Action): + code = "marked" + help = "Match marked flows" + + def __call__(self, f): + return f.marked + + class FHTTP(_Action): code = "http" help = "Match HTTP flows" @@ -401,6 +409,7 @@ filt_unary = [ FAsset, FErr, FHTTP, + FMarked, FReq, FResp, FTCP, diff --git a/mitmproxy/flow/io_compat.py b/mitmproxy/flow/io_compat.py index 8cd883c33..061bf16de 100644 --- a/mitmproxy/flow/io_compat.py +++ b/mitmproxy/flow/io_compat.py @@ -60,6 +60,7 @@ def convert_017_018(data): data = convert_unicode(data) data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address") + data["marked"] = False data["version"] = (0, 18) return data diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index f4993b7a2..f4a2b54b1 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -8,6 +8,8 @@ from mitmproxy import stateobject from mitmproxy.models.connections import ClientConnection from mitmproxy.models.connections import ServerConnection +import six + from netlib import version from typing import Optional # noqa @@ -79,6 +81,7 @@ class Flow(stateobject.StateObject): self.intercepted = False # type: bool self._backup = None # type: Optional[Flow] self.reply = None + self.marked = False # type: bool _stateobject_attributes = dict( id=str, @@ -86,7 +89,8 @@ class Flow(stateobject.StateObject): client_conn=ClientConnection, server_conn=ServerConnection, type=str, - intercepted=bool + intercepted=bool, + marked=bool, ) def get_state(self): @@ -173,3 +177,21 @@ class Flow(stateobject.StateObject): self.intercepted = False self.reply.ack() master.handle_accept_intercept(self) + + def match(self, f): + """ + Match this flow against a compiled filter expression. Returns True + if matched, False if not. + + If f is a string, it will be compiled as a filter expression. If + the expression is invalid, ValueError is raised. + """ + if isinstance(f, six.string_types): + from .. import filt + + f = filt.parse(f) + if not f: + raise ValueError("Invalid filter expression.") + if f: + return f(self) + return True diff --git a/mitmproxy/models/http.py b/mitmproxy/models/http.py index 1fd28f00e..7781e61fe 100644 --- a/mitmproxy/models/http.py +++ b/mitmproxy/models/http.py @@ -2,7 +2,6 @@ from __future__ import absolute_import, print_function, division import cgi import warnings -import six from mitmproxy.models.flow import Flow from netlib import version @@ -211,24 +210,6 @@ class HTTPFlow(Flow): f.response = self.response.copy() return f - def match(self, f): - """ - Match this flow against a compiled filter expression. Returns True - if matched, False if not. - - If f is a string, it will be compiled as a filter expression. If - the expression is invalid, ValueError is raised. - """ - if isinstance(f, six.string_types): - from .. import filt - - f = filt.parse(f) - if not f: - raise ValueError("Invalid filter expression.") - if f: - return f(self) - return True - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both request and diff --git a/mitmproxy/models/tcp.py b/mitmproxy/models/tcp.py index 6650141d0..e33475c23 100644 --- a/mitmproxy/models/tcp.py +++ b/mitmproxy/models/tcp.py @@ -7,8 +7,6 @@ from typing import List import netlib.basetypes from mitmproxy.models.flow import Flow -import six - class TCPMessage(netlib.basetypes.Serializable): @@ -55,22 +53,3 @@ class TCPFlow(Flow): def __repr__(self): return "".format(len(self.messages)) - - def match(self, f): - """ - Match this flow against a compiled filter expression. Returns True - if matched, False if not. - - If f is a string, it will be compiled as a filter expression. If - the expression is invalid, ValueError is raised. - """ - if isinstance(f, six.string_types): - from .. import filt - - f = filt.parse(f) - if not f: - raise ValueError("Invalid filter expression.") - if f: - return f(self) - - return True diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 36b212a78..74992130e 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -615,6 +615,7 @@ class TestSerialize: def test_roundtrip(self): sio = io.BytesIO() f = tutils.tflow() + f.marked = True f.request.content = bytes(bytearray(range(256))) w = flow.FlowWriter(sio) w.add(f) @@ -627,6 +628,7 @@ class TestSerialize: f2 = l[0] assert f2.get_state() == f.get_state() assert f2.request == f.request + assert f2.marked def test_load_flows(self): r = self._treader()