finalize Reply semantics, fix tests

This commit is contained in:
Maximilian Hils 2016-08-09 20:26:24 -07:00
parent f719644aa1
commit 818840f553
17 changed files with 183 additions and 68 deletions

View File

@ -41,9 +41,9 @@ class Replace:
f.request.replace(rex, s)
def request(self, flow):
if not flow.reply.acked:
if not flow.reply.has_message:
self.execute(flow)
def response(self, flow):
if not flow.reply.acked:
if not flow.reply.has_message:
self.execute(flow)

View File

@ -31,9 +31,9 @@ class SetHeaders:
hdrs.add(header, value)
def request(self, flow):
if not flow.reply.acked:
if not flow.reply.has_message:
self.run(flow, flow.request.headers)
def response(self, flow):
if not flow.reply.acked:
if not flow.reply.has_message:
self.run(flow, flow.response.headers)

View File

@ -413,7 +413,7 @@ def raw_format_flow(f, focus, extended):
def format_flow(f, focus, extended=False, hostheader=False):
d = dict(
intercepted = f.intercepted,
acked = f.reply.acked,
acked = f.reply.state == "committed",
req_timestamp = f.request.timestamp_start,
req_is_replay = f.request.is_replay,

View File

@ -182,7 +182,7 @@ class ConnectionItem(urwid.WidgetWrap):
self.flow.accept_intercept(self.master)
signals.flowlist_change.send(self)
elif key == "d":
if not self.flow.reply.acked:
if self.flow.reply and self.flow.reply.state != "committed":
self.flow.kill(self.master)
self.state.delete_flow(self.flow)
signals.flowlist_change.send(self)
@ -246,7 +246,7 @@ class ConnectionItem(urwid.WidgetWrap):
callback = self.save_flows_prompt,
)
elif key == "X":
if not self.flow.reply.acked:
if self.flow.reply and self.flow.reply.state != "committed":
self.flow.kill(self.master)
elif key == "enter":
if self.flow.request:

View File

@ -148,13 +148,13 @@ class FlowView(tabs.Tabs):
signals.flow_change.connect(self.sig_flow_change)
def tab_request(self):
if self.flow.intercepted and not self.flow.reply.acked and not self.flow.response:
if self.flow.intercepted and not self.flow.response:
return "Request intercepted"
else:
return "Request"
def tab_response(self):
if self.flow.intercepted and not self.flow.reply.acked and self.flow.response:
if self.flow.intercepted and self.flow.response:
return "Response intercepted"
else:
return "Response"
@ -379,7 +379,6 @@ class FlowView(tabs.Tabs):
self.flow.request.http_version,
200, b"OK", Headers(), b""
)
self.flow.response.reply = controller.DummyReply()
message = self.flow.response
self.flow.backup()

View File

@ -736,7 +736,6 @@ class ConsoleMaster(flow.FlowMaster):
)
if should_intercept:
f.intercept(self)
f.reply.take()
signals.flowlist_change.send(self)
signals.flow_change.send(self, flow = f)

View File

@ -185,6 +185,7 @@ class Channel(object):
if g == exceptions.Kill:
raise exceptions.Kill()
return g
m.reply._state = "committed" # suppress error message in __del__
raise exceptions.Kill()
def tell(self, mtype, m):
@ -219,10 +220,15 @@ def handler(f):
# Reset the handled flag - it's common for us to feed the same object
# through handlers repeatedly, so we don't want this to persist across
# calls.
if handling and not message.reply.taken:
if message.reply.value == NO_REPLY:
if handling and message.reply.state == "handled":
message.reply.take()
if not message.reply.has_message:
message.reply.ack()
message.reply.commit()
# DummyReplys may be reused multiple times.
if isinstance(message.reply, DummyReply):
message.reply.reset()
return ret
# Mark this function as a handler wrapper
wrapper.__dict__["__handler"] = True
@ -240,7 +246,7 @@ class Reply(object):
"""
def __init__(self, obj):
self.obj = obj
self.q = queue.Queue()
self.q = queue.Queue() # type: queue.Queue
self._state = "unhandled" # "unhandled" -> "handled" -> "taken" -> "committed"
self.value = NO_REPLY # holds the reply value. May change before things are actually commited.
@ -259,13 +265,17 @@ class Reply(object):
"""
return self._state
@property
def has_message(self):
return self.value != NO_REPLY
def handle(self):
"""
Reply are handled by controller.handlers, which may be nested. The first handler takes
responsibility and handles the reply.
"""
if self.state != "unhandled":
raise exceptions.ControlException("Message is {}, but expected it to be unhandled.".format(self.state))
raise exceptions.ControlException("Reply is {}, but expected it to be unhandled.".format(self.state))
self._state = "handled"
def take(self):
@ -274,7 +284,7 @@ class Reply(object):
For example, intercepted flows are taken out so that the connection thread does not proceed.
"""
if self.state != "handled":
raise exceptions.ControlException("Message is {}, but expected it to be handled.".format(self.state))
raise exceptions.ControlException("Reply is {}, but expected it to be handled.".format(self.state))
self._state = "taken"
def commit(self):
@ -283,42 +293,48 @@ class Reply(object):
if the message is not taken or manually by the entity which called .take().
"""
if self.state != "taken":
raise exceptions.ControlException("Message is {}, but expected it to be taken.".format(self.state))
raise exceptions.ControlException("Reply is {}, but expected it to be taken.".format(self.state))
if not self.has_message:
raise exceptions.ControlException("There is no reply message.")
self._state = "committed"
self.q.put(self.value)
def ack(self):
self.send(self.obj)
def ack(self, force=False):
self.send(self.obj, force)
def kill(self):
self.send(exceptions.Kill)
def kill(self, force=False):
self.send(exceptions.Kill, force)
def send(self, msg):
def send(self, msg, force=False):
if self.state not in ("handled", "taken"):
raise exceptions.ControlException(
"Reply is currently {}, did not expect a call to .send().".format(self.state)
"Reply is {}, did not expect a call to .send().".format(self.state)
)
if self.value is not NO_REPLY:
raise exceptions.ControlException("There is already a reply for this message.")
if self.has_message and not force:
raise exceptions.ControlException("There is already a reply message.")
self.value = msg
def __del__(self):
if self.state != "comitted":
if self.state != "committed":
# This will be ignored by the interpreter, but emit a warning
raise exceptions.ControlException("Uncomitted message: %s" % self.obj)
raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
class DummyReply(Reply):
"""
A reply object that is not connected to anything. In contrast to regular Reply objects,
DummyReply objects are reset to "unhandled" after a commit so that they can be used
DummyReply objects are reset to "unhandled" at the end of an handler so that they can be used
multiple times. Useful when we need an object to seem like it has a channel,
and during testing.
"""
def __init__(self):
super(DummyReply, self).__init__(None)
def commit(self):
super(DummyReply, self).commit()
def reset(self):
if self.state != "committed":
raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
self._state = "unhandled"
self.value = NO_REPLY
def __del__(self):
pass

View File

@ -314,8 +314,7 @@ class FlowMaster(controller.Master):
return
if f not in self.state.flows: # don't add again on replay
self.state.add_flow(f)
if not f.reply.acked:
self.process_new_request(f)
self.process_new_request(f)
return f
@controller.handler
@ -331,9 +330,8 @@ class FlowMaster(controller.Master):
@controller.handler
def response(self, f):
self.state.update_flow(f)
if not f.reply.acked:
if self.client_playback:
self.client_playback.clear(f)
if self.client_playback:
self.client_playback.clear(f)
return f
def handle_intercept(self, f):

View File

@ -178,7 +178,7 @@ class FlowStore(FlowList):
def kill_all(self, master):
for f in self._list:
if not f.reply.acked:
if f.reply.state != "committed":
f.kill(master)

View File

@ -155,7 +155,12 @@ class Flow(stateobject.StateObject):
"""
self.error = Error("Connection killed")
self.intercepted = False
self.reply.kill()
# reply.state should only be handled or taken here.
# if none of this is the case, .take() will raise an exception.
if self.reply.state != "taken":
self.reply.take()
self.reply.kill(force=True)
self.reply.commit()
master.error(self)
def intercept(self, master):
@ -166,6 +171,7 @@ class Flow(stateobject.StateObject):
if self.intercepted:
return
self.intercepted = True
self.reply.take()
master.handle_intercept(self)
def accept_intercept(self, master):
@ -176,6 +182,7 @@ class Flow(stateobject.StateObject):
return
self.intercepted = False
self.reply.ack()
self.reply.commit()
master.handle_accept_intercept(self)
def match(self, f):

View File

@ -13,7 +13,7 @@ class ScriptThread(basethread.BaseThread):
def concurrent(fn):
if fn.__name__ not in controller.Events - set(["start", "configure", "tick"]):
if fn.__name__ not in controller.Events - {"start", "configure", "tick"}:
raise NotImplementedError(
"Concurrent decorator not supported for '%s' method." % fn.__name__
)
@ -21,8 +21,10 @@ def concurrent(fn):
def _concurrent(obj):
def run():
fn(obj)
if not obj.reply.acked:
obj.reply.ack()
if obj.reply.state == "taken":
if not obj.reply.has_message:
obj.reply.ack()
obj.reply.commit()
obj.reply.take()
ScriptThread(
"script.concurrent (%s)" % fn.__name__,

View File

@ -234,7 +234,7 @@ class AcceptFlow(RequestHandler):
class FlowHandler(RequestHandler):
def delete(self, flow_id):
if not self.flow.reply.acked:
if self.flow.reply.state != "committed":
self.flow.kill(self.master)
self.state.delete_flow(self.flow)
@ -438,6 +438,7 @@ class Application(tornado.web.Application):
xsrf_cookies=True,
cookie_secret=os.urandom(256),
debug=debug,
autoreload=False,
wauthenticator=wauthenticator,
)
super(Application, self).__init__(handlers, **settings)

View File

@ -183,7 +183,6 @@ class WebMaster(flow.FlowMaster):
if self.state.intercept and self.state.intercept(
f) and not f.request.is_replay:
f.intercept(self)
f.reply.take()
return f
@controller.handler

View File

@ -12,13 +12,11 @@ class MasterTest:
with master.handlecontext():
func = getattr(master, handler)
func(*message)
if message:
message[0].reply = controller.DummyReply()
def cycle(self, master, content):
f = tutils.tflow(req=netlib.tutils.treq(content=content))
l = proxy.Log("connect")
l.reply = mock.MagicMock()
l.reply = controller.DummyReply()
master.log(l)
self.invoke(master, "clientconnect", f.client_conn)
self.invoke(master, "clientconnect", f.client_conn)

View File

@ -29,7 +29,7 @@ class TestConcurrent(mastertest.MasterTest):
self.invoke(m, "request", f2)
start = time.time()
while time.time() - start < 5:
if f1.reply.acked and f2.reply.acked:
if f1.reply.state == f2.reply.state == "committed":
return
raise ValueError("Script never acked")

View File

@ -1,3 +1,4 @@
from test.mitmproxy import tutils
from threading import Thread, Event
from mock import Mock
@ -5,7 +6,7 @@ from mock import Mock
from mitmproxy import controller
from six.moves import queue
from mitmproxy.exceptions import Kill
from mitmproxy.exceptions import Kill, ControlException
from mitmproxy.proxy import DummyServer
from netlib.tutils import raises
@ -55,7 +56,7 @@ class TestChannel(object):
def test_tell(self):
q = queue.Queue()
channel = controller.Channel(q, Event())
m = Mock()
m = Mock(name="test_tell")
channel.tell("test", m)
assert q.get() == ("test", m)
assert m.reply
@ -66,12 +67,15 @@ class TestChannel(object):
def reply():
m, obj = q.get()
assert m == "test"
obj.reply.handle()
obj.reply.send(42)
obj.reply.take()
obj.reply.commit()
Thread(target=reply).start()
channel = controller.Channel(q, Event())
assert channel.ask("test", Mock()) == 42
assert channel.ask("test", Mock(name="test_ask_simple")) == 42
def test_ask_shutdown(self):
q = queue.Queue()
@ -79,31 +83,122 @@ class TestChannel(object):
done.set()
channel = controller.Channel(q, done)
with raises(Kill):
channel.ask("test", Mock())
class TestDummyReply(object):
def test_simple(self):
reply = controller.DummyReply()
assert not reply.acked
reply.ack()
assert reply.acked
channel.ask("test", Mock(name="test_ask_shutdown"))
class TestReply(object):
def test_simple(self):
reply = controller.Reply(42)
assert not reply.acked
assert reply.state == "unhandled"
reply.handle()
assert reply.state == "handled"
reply.send("foo")
assert reply.acked
assert reply.value == "foo"
reply.take()
assert reply.state == "taken"
with tutils.raises(queue.Empty):
reply.q.get_nowait()
reply.commit()
assert reply.state == "committed"
assert reply.q.get() == "foo"
def test_default(self):
reply = controller.Reply(42)
def test_kill(self):
reply = controller.Reply(43)
reply.handle()
reply.kill()
reply.take()
reply.commit()
assert reply.q.get() == Kill
def test_ack(self):
reply = controller.Reply(44)
reply.handle()
reply.ack()
assert reply.q.get() == 42
reply.take()
reply.commit()
assert reply.q.get() == 44
def test_reply_none(self):
reply = controller.Reply(42)
reply = controller.Reply(45)
reply.handle()
reply.send(None)
reply.take()
reply.commit()
assert reply.q.get() is None
def test_commit_no_reply(self):
reply = controller.Reply(46)
reply.handle()
reply.take()
with tutils.raises(ControlException):
reply.commit()
reply.ack()
reply.commit()
def test_double_send(self):
reply = controller.Reply(47)
reply.handle()
reply.send(1)
with tutils.raises(ControlException):
reply.send(2)
reply.take()
reply.commit()
def test_state_transitions(self):
states = {"unhandled", "handled", "taken", "committed"}
accept = {
"handle": {"unhandled"},
"take": {"handled"},
"commit": {"taken"},
"ack": {"handled", "taken"},
}
for fn, ok in accept.items():
for state in states:
r = controller.Reply(48)
r._state = state
if fn == "commit":
r.value = 49
if state in ok:
getattr(r, fn)()
else:
with tutils.raises(ControlException):
getattr(r, fn)()
r._state = "committed" # hide warnings on deletion
def test_del(self):
reply = controller.Reply(47)
with tutils.raises(ControlException):
reply.__del__()
reply.handle()
reply.ack()
reply.take()
reply.commit()
class TestDummyReply(object):
def test_simple(self):
reply = controller.DummyReply()
for _ in range(2):
reply.handle()
reply.ack()
reply.take()
reply.commit()
reply.reset()
assert reply.state == "unhandled"
def test_reset(self):
reply = controller.DummyReply()
reply.handle()
reply.ack()
reply.take()
reply.commit()
reply.reset()
assert reply.state == "unhandled"
def test_del(self):
reply = controller.DummyReply()
reply.__del__()

View File

@ -375,6 +375,7 @@ class TestHTTPFlow(object):
s = flow.State()
fm = flow.FlowMaster(None, None, s)
f = tutils.tflow()
f.reply.handle()
f.intercept(mock.Mock())
f.kill(fm)
for i in s.view:
@ -385,6 +386,7 @@ class TestHTTPFlow(object):
fm = flow.FlowMaster(None, None, s)
f = tutils.tflow()
f.reply.handle()
f.intercept(fm)
s.killall(fm)
@ -393,11 +395,11 @@ class TestHTTPFlow(object):
def test_accept_intercept(self):
f = tutils.tflow()
f.reply.handle()
f.intercept(mock.Mock())
assert not f.reply.acked
assert f.reply.state == "taken"
f.accept_intercept(mock.Mock())
assert f.reply.acked
assert f.reply.state == "committed"
def test_replace_unicode(self):
f = tutils.tflow(resp=True)
@ -735,7 +737,6 @@ class TestFlowMaster:
fm.clientdisconnect(f.client_conn)
f.error = Error("msg")
f.error.reply = controller.DummyReply()
fm.error(f)
fm.shutdown()