diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index 9bac5773c..e835340e1 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -124,6 +124,10 @@ class StatusBar(common.WWrap): def get_status(self): r = [] + if self.master.setheaders.count(): + r.append("[") + r.append(("heading_key", "H")) + r.append("eaders]") if self.master.replacehooks.count(): r.append("[") r.append(("heading_key", "R")) @@ -762,11 +766,6 @@ class ConsoleMaster(flow.FlowMaster): else: self.view_flowlist() - def set_replace(self, r): - self.replacehooks.clear() - for i in r: - self.replacehooks.add(*i) - def loop(self): changed = True try: @@ -815,6 +814,14 @@ class ConsoleMaster(flow.FlowMaster): ), self.stop_client_playback_prompt, ) + elif k == "H": + self.view_grideditor( + grideditor.SetHeadersEditor( + self, + self.setheaders.get_specs(), + self.setheaders.set + ) + ) elif k == "i": self.prompt( "Intercept filter: ", @@ -853,7 +860,7 @@ class ConsoleMaster(flow.FlowMaster): grideditor.ReplaceEditor( self, self.replacehooks.get_specs(), - self.set_replace + self.replacehooks.set ) ) elif k == "s": diff --git a/libmproxy/console/grideditor.py b/libmproxy/console/grideditor.py index 51002e773..d62cb2067 100644 --- a/libmproxy/console/grideditor.py +++ b/libmproxy/console/grideditor.py @@ -373,3 +373,14 @@ class ReplaceEditor(GridEditor): return True return False + +class SetHeadersEditor(GridEditor): + title = "Editing header set patterns" + columns = 3 + headings = ("Filter", "Header", "Value") + def is_error(self, col, val): + if col == 0: + if not filt.parse(val): + return True + return False + diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 2d016a140..6076d2023 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -35,6 +35,11 @@ class ReplaceHooks: def __init__(self): self.lst = [] + def set(self, r): + self.clear() + for i in r: + self.add(*i) + def add(self, fpatt, rex, s): """ Add a replacement hook. @@ -56,17 +61,6 @@ class ReplaceHooks: self.lst.append((fpatt, rex, s, cpatt)) return True - def remove(self, fpatt, rex, s): - """ - Remove a hook. - - patt: A string specifying a filter pattern. - func: Optional callable. If not specified, all hooks matching patt are removed. - """ - for i in range(len(self.lst)-1, -1, -1): - if (fpatt, rex, s) == self.lst[i][:3]: - del self.lst[i] - def get_specs(self): """ Retrieve the hook specifcations. Returns a list of (fpatt, rex, s) tuples. @@ -88,6 +82,59 @@ class ReplaceHooks: self.lst = [] +class SetHeaders: + def __init__(self): + self.lst = [] + + def set(self, r): + self.clear() + for i in r: + self.add(*i) + + def add(self, fpatt, header, value): + """ + Add a set header hook. + + fpatt: String specifying a filter pattern. + header: Header name. + value: Header value string + + Returns True if hook was added, False if the pattern could not be + parsed. + """ + cpatt = filt.parse(fpatt) + if not cpatt: + return False + self.lst.append((fpatt, header, value, cpatt)) + return True + + def get_specs(self): + """ + Retrieve the hook specifcations. Returns a list of (fpatt, rex, s) tuples. + """ + return [i[:3] for i in self.lst] + + def count(self): + return len(self.lst) + + def clear(self): + self.lst = [] + + def run(self, f): + for _, header, value, cpatt in self.lst: + if cpatt(f): + if f.response: + del f.response.headers[header] + else: + del f.request.headers[header] + for _, header, value, cpatt in self.lst: + if cpatt(f): + if f.response: + f.response.headers.add(header, value) + else: + f.request.headers.add(header, value) + + class ScriptContext: def __init__(self, master): self._master = master @@ -1220,6 +1267,7 @@ class FlowMaster(controller.Master): self.anticomp = False self.refresh_server_playback = False self.replacehooks = ReplaceHooks() + self.setheaders = SetHeaders() self.stream = None @@ -1426,6 +1474,7 @@ class FlowMaster(controller.Master): def handle_error(self, r): f = self.state.add_error(r) self.replacehooks.run(f) + self.setheaders.run(f) if f: self.run_script_hook("error", f) if self.client_playback: @@ -1436,6 +1485,7 @@ class FlowMaster(controller.Master): def handle_request(self, r): f = self.state.add_request(r) self.replacehooks.run(f) + self.setheaders.run(f) self.run_script_hook("request", f) self.process_new_request(f) return f @@ -1444,6 +1494,7 @@ class FlowMaster(controller.Master): f = self.state.add_response(r) if f: self.replacehooks.run(f) + self.setheaders.run(f) self.run_script_hook("response", f) if self.client_playback: self.client_playback.clear(f) diff --git a/test/test_flow.py b/test/test_flow.py index eccd11f4f..47a093607 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -1017,7 +1017,16 @@ def test_replacehooks(): h = flow.ReplaceHooks() h.add("~q", "foo", "bar") assert h.lst - h.remove("~q", "foo", "bar") + + h.set( + [ + (".*", "one", "two"), + (".*", "three", "four"), + ] + ) + assert h.count() == 2 + + h.clear() assert not h.lst h.add("~q", "foo", "bar") @@ -1026,10 +1035,6 @@ def test_replacehooks(): v = h.get_specs() assert v == [('~q', 'foo', 'bar'), ('~s', 'foo', 'bar')] assert h.count() == 2 - h.remove("~q", "foo", "bar") - assert h.count() == 1 - h.remove("~q", "foo", "bar") - assert h.count() == 1 h.clear() assert h.count() == 0 @@ -1056,3 +1061,55 @@ def test_replacehooks(): assert not h.add("~", "foo", "bar") assert not h.add("foo", "*", "bar") + +def test_setheaders(): + h = flow.SetHeaders() + h.add("~q", "foo", "bar") + assert h.lst + + h.set( + [ + (".*", "one", "two"), + (".*", "three", "four"), + ] + ) + assert h.count() == 2 + + h.clear() + assert not h.lst + + h.add("~q", "foo", "bar") + h.add("~s", "foo", "bar") + + v = h.get_specs() + assert v == [('~q', 'foo', 'bar'), ('~s', 'foo', 'bar')] + assert h.count() == 2 + h.clear() + assert h.count() == 0 + + f = tutils.tflow() + f.request.content = "foo" + h.add("~s", "foo", "bar") + h.run(f) + assert f.request.content == "foo" + + + h.clear() + h.add("~s", "one", "two") + h.add("~s", "one", "three") + f = tutils.tflow_full() + f.request.headers["one"] = ["xxx"] + f.response.headers["one"] = ["xxx"] + h.run(f) + assert f.request.headers["one"] == ["xxx"] + assert f.response.headers["one"] == ["two", "three"] + + h.clear() + h.add("~q", "one", "two") + h.add("~q", "one", "three") + f = tutils.tflow() + f.request.headers["one"] = ["xxx"] + h.run(f) + assert f.request.headers["one"] == ["two", "three"] + + assert not h.add("~", "foo", "bar")