From 818840f553666c9993c7fa6fec3871d80764a282 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 9 Aug 2016 20:26:24 -0700 Subject: [PATCH] finalize Reply semantics, fix tests --- mitmproxy/builtins/replace.py | 4 +- mitmproxy/builtins/setheaders.py | 4 +- mitmproxy/console/common.py | 2 +- mitmproxy/console/flowlist.py | 4 +- mitmproxy/console/flowview.py | 5 +- mitmproxy/console/master.py | 1 - mitmproxy/controller.py | 54 ++++++---- mitmproxy/flow/master.py | 8 +- mitmproxy/flow/state.py | 2 +- mitmproxy/models/flow.py | 9 +- mitmproxy/script/concurrent.py | 8 +- mitmproxy/web/app.py | 3 +- mitmproxy/web/master.py | 1 - test/mitmproxy/mastertest.py | 4 +- test/mitmproxy/script/test_concurrent.py | 2 +- test/mitmproxy/test_controller.py | 131 +++++++++++++++++++---- test/mitmproxy/test_flow.py | 9 +- 17 files changed, 183 insertions(+), 68 deletions(-) diff --git a/mitmproxy/builtins/replace.py b/mitmproxy/builtins/replace.py index 2c94fbb53..c938d6838 100644 --- a/mitmproxy/builtins/replace.py +++ b/mitmproxy/builtins/replace.py @@ -41,9 +41,9 @@ class Replace: f.request.replace(rex, s) def request(self, flow): - if not flow.reply.acked: + if not flow.reply.has_message: self.execute(flow) def response(self, flow): - if not flow.reply.acked: + if not flow.reply.has_message: self.execute(flow) diff --git a/mitmproxy/builtins/setheaders.py b/mitmproxy/builtins/setheaders.py index 4a784a1d7..4cb9905ec 100644 --- a/mitmproxy/builtins/setheaders.py +++ b/mitmproxy/builtins/setheaders.py @@ -31,9 +31,9 @@ class SetHeaders: hdrs.add(header, value) def request(self, flow): - if not flow.reply.acked: + if not flow.reply.has_message: self.run(flow, flow.request.headers) def response(self, flow): - if not flow.reply.acked: + if not flow.reply.has_message: self.run(flow, flow.response.headers) diff --git a/mitmproxy/console/common.py b/mitmproxy/console/common.py index 2eb6a7d93..5a24e789b 100644 --- a/mitmproxy/console/common.py +++ b/mitmproxy/console/common.py @@ -413,7 +413,7 @@ def raw_format_flow(f, focus, extended): def format_flow(f, focus, extended=False, hostheader=False): d = dict( intercepted = f.intercepted, - acked = f.reply.acked, + acked = f.reply.state == "committed", req_timestamp = f.request.timestamp_start, req_is_replay = f.request.is_replay, diff --git a/mitmproxy/console/flowlist.py b/mitmproxy/console/flowlist.py index 12caf3157..571119edf 100644 --- a/mitmproxy/console/flowlist.py +++ b/mitmproxy/console/flowlist.py @@ -182,7 +182,7 @@ class ConnectionItem(urwid.WidgetWrap): self.flow.accept_intercept(self.master) signals.flowlist_change.send(self) elif key == "d": - if not self.flow.reply.acked: + if self.flow.reply and self.flow.reply.state != "committed": self.flow.kill(self.master) self.state.delete_flow(self.flow) signals.flowlist_change.send(self) @@ -246,7 +246,7 @@ class ConnectionItem(urwid.WidgetWrap): callback = self.save_flows_prompt, ) elif key == "X": - if not self.flow.reply.acked: + if self.flow.reply and self.flow.reply.state != "committed": self.flow.kill(self.master) elif key == "enter": if self.flow.request: diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py index 1c3c4e980..22314396d 100644 --- a/mitmproxy/console/flowview.py +++ b/mitmproxy/console/flowview.py @@ -148,13 +148,13 @@ class FlowView(tabs.Tabs): signals.flow_change.connect(self.sig_flow_change) def tab_request(self): - if self.flow.intercepted and not self.flow.reply.acked and not self.flow.response: + if self.flow.intercepted and not self.flow.response: return "Request intercepted" else: return "Request" def tab_response(self): - if self.flow.intercepted and not self.flow.reply.acked and self.flow.response: + if self.flow.intercepted and self.flow.response: return "Response intercepted" else: return "Response" @@ -379,7 +379,6 @@ class FlowView(tabs.Tabs): self.flow.request.http_version, 200, b"OK", Headers(), b"" ) - self.flow.response.reply = controller.DummyReply() message = self.flow.response self.flow.backup() diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py index 18a4c1f05..a6942ca40 100644 --- a/mitmproxy/console/master.py +++ b/mitmproxy/console/master.py @@ -736,7 +736,6 @@ class ConsoleMaster(flow.FlowMaster): ) if should_intercept: f.intercept(self) - f.reply.take() signals.flowlist_change.send(self) signals.flow_change.send(self, flow = f) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index cbac25d85..5ef81f8ce 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -185,6 +185,7 @@ class Channel(object): if g == exceptions.Kill: raise exceptions.Kill() return g + m.reply._state = "committed" # suppress error message in __del__ raise exceptions.Kill() def tell(self, mtype, m): @@ -219,10 +220,15 @@ def handler(f): # Reset the handled flag - it's common for us to feed the same object # through handlers repeatedly, so we don't want this to persist across # calls. - if handling and not message.reply.taken: - if message.reply.value == NO_REPLY: + if handling and message.reply.state == "handled": + message.reply.take() + if not message.reply.has_message: message.reply.ack() message.reply.commit() + + # DummyReplys may be reused multiple times. + if isinstance(message.reply, DummyReply): + message.reply.reset() return ret # Mark this function as a handler wrapper wrapper.__dict__["__handler"] = True @@ -240,7 +246,7 @@ class Reply(object): """ def __init__(self, obj): self.obj = obj - self.q = queue.Queue() + self.q = queue.Queue() # type: queue.Queue self._state = "unhandled" # "unhandled" -> "handled" -> "taken" -> "committed" self.value = NO_REPLY # holds the reply value. May change before things are actually commited. @@ -259,13 +265,17 @@ class Reply(object): """ return self._state + @property + def has_message(self): + return self.value != NO_REPLY + def handle(self): """ Reply are handled by controller.handlers, which may be nested. The first handler takes responsibility and handles the reply. """ if self.state != "unhandled": - raise exceptions.ControlException("Message is {}, but expected it to be unhandled.".format(self.state)) + raise exceptions.ControlException("Reply is {}, but expected it to be unhandled.".format(self.state)) self._state = "handled" def take(self): @@ -274,7 +284,7 @@ class Reply(object): For example, intercepted flows are taken out so that the connection thread does not proceed. """ if self.state != "handled": - raise exceptions.ControlException("Message is {}, but expected it to be handled.".format(self.state)) + raise exceptions.ControlException("Reply is {}, but expected it to be handled.".format(self.state)) self._state = "taken" def commit(self): @@ -283,42 +293,48 @@ class Reply(object): if the message is not taken or manually by the entity which called .take(). """ if self.state != "taken": - raise exceptions.ControlException("Message is {}, but expected it to be taken.".format(self.state)) + raise exceptions.ControlException("Reply is {}, but expected it to be taken.".format(self.state)) + if not self.has_message: + raise exceptions.ControlException("There is no reply message.") self._state = "committed" self.q.put(self.value) - def ack(self): - self.send(self.obj) + def ack(self, force=False): + self.send(self.obj, force) - def kill(self): - self.send(exceptions.Kill) + def kill(self, force=False): + self.send(exceptions.Kill, force) - def send(self, msg): + def send(self, msg, force=False): if self.state not in ("handled", "taken"): raise exceptions.ControlException( - "Reply is currently {}, did not expect a call to .send().".format(self.state) + "Reply is {}, did not expect a call to .send().".format(self.state) ) - if self.value is not NO_REPLY: - raise exceptions.ControlException("There is already a reply for this message.") + if self.has_message and not force: + raise exceptions.ControlException("There is already a reply message.") self.value = msg def __del__(self): - if self.state != "comitted": + if self.state != "committed": # This will be ignored by the interpreter, but emit a warning - raise exceptions.ControlException("Uncomitted message: %s" % self.obj) + raise exceptions.ControlException("Uncommitted reply: %s" % self.obj) class DummyReply(Reply): """ A reply object that is not connected to anything. In contrast to regular Reply objects, - DummyReply objects are reset to "unhandled" after a commit so that they can be used + DummyReply objects are reset to "unhandled" at the end of an handler so that they can be used multiple times. Useful when we need an object to seem like it has a channel, and during testing. """ def __init__(self): super(DummyReply, self).__init__(None) - def commit(self): - super(DummyReply, self).commit() + def reset(self): + if self.state != "committed": + raise exceptions.ControlException("Uncommitted reply: %s" % self.obj) self._state = "unhandled" self.value = NO_REPLY + + def __del__(self): + pass diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index 65a95e44d..e02cdc159 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -314,8 +314,7 @@ class FlowMaster(controller.Master): return if f not in self.state.flows: # don't add again on replay self.state.add_flow(f) - if not f.reply.acked: - self.process_new_request(f) + self.process_new_request(f) return f @controller.handler @@ -331,9 +330,8 @@ class FlowMaster(controller.Master): @controller.handler def response(self, f): self.state.update_flow(f) - if not f.reply.acked: - if self.client_playback: - self.client_playback.clear(f) + if self.client_playback: + self.client_playback.clear(f) return f def handle_intercept(self, f): diff --git a/mitmproxy/flow/state.py b/mitmproxy/flow/state.py index efcb2d892..6f7e6c034 100644 --- a/mitmproxy/flow/state.py +++ b/mitmproxy/flow/state.py @@ -178,7 +178,7 @@ class FlowStore(FlowList): def kill_all(self, master): for f in self._list: - if not f.reply.acked: + if f.reply.state != "committed": f.kill(master) diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index f4a2b54b1..5d61e47fc 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -155,7 +155,12 @@ class Flow(stateobject.StateObject): """ self.error = Error("Connection killed") self.intercepted = False - self.reply.kill() + # reply.state should only be handled or taken here. + # if none of this is the case, .take() will raise an exception. + if self.reply.state != "taken": + self.reply.take() + self.reply.kill(force=True) + self.reply.commit() master.error(self) def intercept(self, master): @@ -166,6 +171,7 @@ class Flow(stateobject.StateObject): if self.intercepted: return self.intercepted = True + self.reply.take() master.handle_intercept(self) def accept_intercept(self, master): @@ -176,6 +182,7 @@ class Flow(stateobject.StateObject): return self.intercepted = False self.reply.ack() + self.reply.commit() master.handle_accept_intercept(self) def match(self, f): diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index 0cc0514e7..9ed08065e 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -13,7 +13,7 @@ class ScriptThread(basethread.BaseThread): def concurrent(fn): - if fn.__name__ not in controller.Events - set(["start", "configure", "tick"]): + if fn.__name__ not in controller.Events - {"start", "configure", "tick"}: raise NotImplementedError( "Concurrent decorator not supported for '%s' method." % fn.__name__ ) @@ -21,8 +21,10 @@ def concurrent(fn): def _concurrent(obj): def run(): fn(obj) - if not obj.reply.acked: - obj.reply.ack() + if obj.reply.state == "taken": + if not obj.reply.has_message: + obj.reply.ack() + obj.reply.commit() obj.reply.take() ScriptThread( "script.concurrent (%s)" % fn.__name__, diff --git a/mitmproxy/web/app.py b/mitmproxy/web/app.py index f8f85f3dd..097de6344 100644 --- a/mitmproxy/web/app.py +++ b/mitmproxy/web/app.py @@ -234,7 +234,7 @@ class AcceptFlow(RequestHandler): class FlowHandler(RequestHandler): def delete(self, flow_id): - if not self.flow.reply.acked: + if self.flow.reply.state != "committed": self.flow.kill(self.master) self.state.delete_flow(self.flow) @@ -438,6 +438,7 @@ class Application(tornado.web.Application): xsrf_cookies=True, cookie_secret=os.urandom(256), debug=debug, + autoreload=False, wauthenticator=wauthenticator, ) super(Application, self).__init__(handlers, **settings) diff --git a/mitmproxy/web/master.py b/mitmproxy/web/master.py index 9ddb61d4c..5751c9dd5 100644 --- a/mitmproxy/web/master.py +++ b/mitmproxy/web/master.py @@ -183,7 +183,6 @@ class WebMaster(flow.FlowMaster): if self.state.intercept and self.state.intercept( f) and not f.request.is_replay: f.intercept(self) - f.reply.take() return f @controller.handler diff --git a/test/mitmproxy/mastertest.py b/test/mitmproxy/mastertest.py index dcc0dc48c..95597d2ce 100644 --- a/test/mitmproxy/mastertest.py +++ b/test/mitmproxy/mastertest.py @@ -12,13 +12,11 @@ class MasterTest: with master.handlecontext(): func = getattr(master, handler) func(*message) - if message: - message[0].reply = controller.DummyReply() def cycle(self, master, content): f = tutils.tflow(req=netlib.tutils.treq(content=content)) l = proxy.Log("connect") - l.reply = mock.MagicMock() + l.reply = controller.DummyReply() master.log(l) self.invoke(master, "clientconnect", f.client_conn) self.invoke(master, "clientconnect", f.client_conn) diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py index a5f769941..07ba1c579 100644 --- a/test/mitmproxy/script/test_concurrent.py +++ b/test/mitmproxy/script/test_concurrent.py @@ -29,7 +29,7 @@ class TestConcurrent(mastertest.MasterTest): self.invoke(m, "request", f2) start = time.time() while time.time() - start < 5: - if f1.reply.acked and f2.reply.acked: + if f1.reply.state == f2.reply.state == "committed": return raise ValueError("Script never acked") diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py index 6d4b8fe63..8fe2453d7 100644 --- a/test/mitmproxy/test_controller.py +++ b/test/mitmproxy/test_controller.py @@ -1,3 +1,4 @@ +from test.mitmproxy import tutils from threading import Thread, Event from mock import Mock @@ -5,7 +6,7 @@ from mock import Mock from mitmproxy import controller from six.moves import queue -from mitmproxy.exceptions import Kill +from mitmproxy.exceptions import Kill, ControlException from mitmproxy.proxy import DummyServer from netlib.tutils import raises @@ -55,7 +56,7 @@ class TestChannel(object): def test_tell(self): q = queue.Queue() channel = controller.Channel(q, Event()) - m = Mock() + m = Mock(name="test_tell") channel.tell("test", m) assert q.get() == ("test", m) assert m.reply @@ -66,12 +67,15 @@ class TestChannel(object): def reply(): m, obj = q.get() assert m == "test" + obj.reply.handle() obj.reply.send(42) + obj.reply.take() + obj.reply.commit() Thread(target=reply).start() channel = controller.Channel(q, Event()) - assert channel.ask("test", Mock()) == 42 + assert channel.ask("test", Mock(name="test_ask_simple")) == 42 def test_ask_shutdown(self): q = queue.Queue() @@ -79,31 +83,122 @@ class TestChannel(object): done.set() channel = controller.Channel(q, done) with raises(Kill): - channel.ask("test", Mock()) - - -class TestDummyReply(object): - def test_simple(self): - reply = controller.DummyReply() - assert not reply.acked - reply.ack() - assert reply.acked + channel.ask("test", Mock(name="test_ask_shutdown")) class TestReply(object): def test_simple(self): reply = controller.Reply(42) - assert not reply.acked + assert reply.state == "unhandled" + + reply.handle() + assert reply.state == "handled" + reply.send("foo") - assert reply.acked + assert reply.value == "foo" + + reply.take() + assert reply.state == "taken" + + with tutils.raises(queue.Empty): + reply.q.get_nowait() + reply.commit() + assert reply.state == "committed" assert reply.q.get() == "foo" - def test_default(self): - reply = controller.Reply(42) + def test_kill(self): + reply = controller.Reply(43) + reply.handle() + reply.kill() + reply.take() + reply.commit() + assert reply.q.get() == Kill + + def test_ack(self): + reply = controller.Reply(44) + reply.handle() reply.ack() - assert reply.q.get() == 42 + reply.take() + reply.commit() + assert reply.q.get() == 44 def test_reply_none(self): - reply = controller.Reply(42) + reply = controller.Reply(45) + reply.handle() reply.send(None) + reply.take() + reply.commit() assert reply.q.get() is None + + def test_commit_no_reply(self): + reply = controller.Reply(46) + reply.handle() + reply.take() + with tutils.raises(ControlException): + reply.commit() + reply.ack() + reply.commit() + + def test_double_send(self): + reply = controller.Reply(47) + reply.handle() + reply.send(1) + with tutils.raises(ControlException): + reply.send(2) + reply.take() + reply.commit() + + def test_state_transitions(self): + states = {"unhandled", "handled", "taken", "committed"} + accept = { + "handle": {"unhandled"}, + "take": {"handled"}, + "commit": {"taken"}, + "ack": {"handled", "taken"}, + } + for fn, ok in accept.items(): + for state in states: + r = controller.Reply(48) + r._state = state + if fn == "commit": + r.value = 49 + if state in ok: + getattr(r, fn)() + else: + with tutils.raises(ControlException): + getattr(r, fn)() + r._state = "committed" # hide warnings on deletion + + def test_del(self): + reply = controller.Reply(47) + with tutils.raises(ControlException): + reply.__del__() + reply.handle() + reply.ack() + reply.take() + reply.commit() + + +class TestDummyReply(object): + def test_simple(self): + reply = controller.DummyReply() + for _ in range(2): + reply.handle() + reply.ack() + reply.take() + reply.commit() + reply.reset() + assert reply.state == "unhandled" + + def test_reset(self): + reply = controller.DummyReply() + reply.handle() + reply.ack() + reply.take() + reply.commit() + reply.reset() + assert reply.state == "unhandled" + + def test_del(self): + reply = controller.DummyReply() + reply.__del__() \ No newline at end of file diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index d4bf764c6..256ee124a 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -375,6 +375,7 @@ class TestHTTPFlow(object): s = flow.State() fm = flow.FlowMaster(None, None, s) f = tutils.tflow() + f.reply.handle() f.intercept(mock.Mock()) f.kill(fm) for i in s.view: @@ -385,6 +386,7 @@ class TestHTTPFlow(object): fm = flow.FlowMaster(None, None, s) f = tutils.tflow() + f.reply.handle() f.intercept(fm) s.killall(fm) @@ -393,11 +395,11 @@ class TestHTTPFlow(object): def test_accept_intercept(self): f = tutils.tflow() - + f.reply.handle() f.intercept(mock.Mock()) - assert not f.reply.acked + assert f.reply.state == "taken" f.accept_intercept(mock.Mock()) - assert f.reply.acked + assert f.reply.state == "committed" def test_replace_unicode(self): f = tutils.tflow(resp=True) @@ -735,7 +737,6 @@ class TestFlowMaster: fm.clientdisconnect(f.client_conn) f.error = Error("msg") - f.error.reply = controller.DummyReply() fm.error(f) fm.shutdown()