mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
finalize Reply semantics, fix tests
This commit is contained in:
parent
f719644aa1
commit
818840f553
@ -41,9 +41,9 @@ class Replace:
|
||||
f.request.replace(rex, s)
|
||||
|
||||
def request(self, flow):
|
||||
if not flow.reply.acked:
|
||||
if not flow.reply.has_message:
|
||||
self.execute(flow)
|
||||
|
||||
def response(self, flow):
|
||||
if not flow.reply.acked:
|
||||
if not flow.reply.has_message:
|
||||
self.execute(flow)
|
||||
|
@ -31,9 +31,9 @@ class SetHeaders:
|
||||
hdrs.add(header, value)
|
||||
|
||||
def request(self, flow):
|
||||
if not flow.reply.acked:
|
||||
if not flow.reply.has_message:
|
||||
self.run(flow, flow.request.headers)
|
||||
|
||||
def response(self, flow):
|
||||
if not flow.reply.acked:
|
||||
if not flow.reply.has_message:
|
||||
self.run(flow, flow.response.headers)
|
||||
|
@ -413,7 +413,7 @@ def raw_format_flow(f, focus, extended):
|
||||
def format_flow(f, focus, extended=False, hostheader=False):
|
||||
d = dict(
|
||||
intercepted = f.intercepted,
|
||||
acked = f.reply.acked,
|
||||
acked = f.reply.state == "committed",
|
||||
|
||||
req_timestamp = f.request.timestamp_start,
|
||||
req_is_replay = f.request.is_replay,
|
||||
|
@ -182,7 +182,7 @@ class ConnectionItem(urwid.WidgetWrap):
|
||||
self.flow.accept_intercept(self.master)
|
||||
signals.flowlist_change.send(self)
|
||||
elif key == "d":
|
||||
if not self.flow.reply.acked:
|
||||
if self.flow.reply and self.flow.reply.state != "committed":
|
||||
self.flow.kill(self.master)
|
||||
self.state.delete_flow(self.flow)
|
||||
signals.flowlist_change.send(self)
|
||||
@ -246,7 +246,7 @@ class ConnectionItem(urwid.WidgetWrap):
|
||||
callback = self.save_flows_prompt,
|
||||
)
|
||||
elif key == "X":
|
||||
if not self.flow.reply.acked:
|
||||
if self.flow.reply and self.flow.reply.state != "committed":
|
||||
self.flow.kill(self.master)
|
||||
elif key == "enter":
|
||||
if self.flow.request:
|
||||
|
@ -148,13 +148,13 @@ class FlowView(tabs.Tabs):
|
||||
signals.flow_change.connect(self.sig_flow_change)
|
||||
|
||||
def tab_request(self):
|
||||
if self.flow.intercepted and not self.flow.reply.acked and not self.flow.response:
|
||||
if self.flow.intercepted and not self.flow.response:
|
||||
return "Request intercepted"
|
||||
else:
|
||||
return "Request"
|
||||
|
||||
def tab_response(self):
|
||||
if self.flow.intercepted and not self.flow.reply.acked and self.flow.response:
|
||||
if self.flow.intercepted and self.flow.response:
|
||||
return "Response intercepted"
|
||||
else:
|
||||
return "Response"
|
||||
@ -379,7 +379,6 @@ class FlowView(tabs.Tabs):
|
||||
self.flow.request.http_version,
|
||||
200, b"OK", Headers(), b""
|
||||
)
|
||||
self.flow.response.reply = controller.DummyReply()
|
||||
message = self.flow.response
|
||||
|
||||
self.flow.backup()
|
||||
|
@ -736,7 +736,6 @@ class ConsoleMaster(flow.FlowMaster):
|
||||
)
|
||||
if should_intercept:
|
||||
f.intercept(self)
|
||||
f.reply.take()
|
||||
signals.flowlist_change.send(self)
|
||||
signals.flow_change.send(self, flow = f)
|
||||
|
||||
|
@ -185,6 +185,7 @@ class Channel(object):
|
||||
if g == exceptions.Kill:
|
||||
raise exceptions.Kill()
|
||||
return g
|
||||
m.reply._state = "committed" # suppress error message in __del__
|
||||
raise exceptions.Kill()
|
||||
|
||||
def tell(self, mtype, m):
|
||||
@ -219,10 +220,15 @@ def handler(f):
|
||||
# Reset the handled flag - it's common for us to feed the same object
|
||||
# through handlers repeatedly, so we don't want this to persist across
|
||||
# calls.
|
||||
if handling and not message.reply.taken:
|
||||
if message.reply.value == NO_REPLY:
|
||||
if handling and message.reply.state == "handled":
|
||||
message.reply.take()
|
||||
if not message.reply.has_message:
|
||||
message.reply.ack()
|
||||
message.reply.commit()
|
||||
|
||||
# DummyReplys may be reused multiple times.
|
||||
if isinstance(message.reply, DummyReply):
|
||||
message.reply.reset()
|
||||
return ret
|
||||
# Mark this function as a handler wrapper
|
||||
wrapper.__dict__["__handler"] = True
|
||||
@ -240,7 +246,7 @@ class Reply(object):
|
||||
"""
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
self.q = queue.Queue()
|
||||
self.q = queue.Queue() # type: queue.Queue
|
||||
|
||||
self._state = "unhandled" # "unhandled" -> "handled" -> "taken" -> "committed"
|
||||
self.value = NO_REPLY # holds the reply value. May change before things are actually commited.
|
||||
@ -259,13 +265,17 @@ class Reply(object):
|
||||
"""
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def has_message(self):
|
||||
return self.value != NO_REPLY
|
||||
|
||||
def handle(self):
|
||||
"""
|
||||
Reply are handled by controller.handlers, which may be nested. The first handler takes
|
||||
responsibility and handles the reply.
|
||||
"""
|
||||
if self.state != "unhandled":
|
||||
raise exceptions.ControlException("Message is {}, but expected it to be unhandled.".format(self.state))
|
||||
raise exceptions.ControlException("Reply is {}, but expected it to be unhandled.".format(self.state))
|
||||
self._state = "handled"
|
||||
|
||||
def take(self):
|
||||
@ -274,7 +284,7 @@ class Reply(object):
|
||||
For example, intercepted flows are taken out so that the connection thread does not proceed.
|
||||
"""
|
||||
if self.state != "handled":
|
||||
raise exceptions.ControlException("Message is {}, but expected it to be handled.".format(self.state))
|
||||
raise exceptions.ControlException("Reply is {}, but expected it to be handled.".format(self.state))
|
||||
self._state = "taken"
|
||||
|
||||
def commit(self):
|
||||
@ -283,42 +293,48 @@ class Reply(object):
|
||||
if the message is not taken or manually by the entity which called .take().
|
||||
"""
|
||||
if self.state != "taken":
|
||||
raise exceptions.ControlException("Message is {}, but expected it to be taken.".format(self.state))
|
||||
raise exceptions.ControlException("Reply is {}, but expected it to be taken.".format(self.state))
|
||||
if not self.has_message:
|
||||
raise exceptions.ControlException("There is no reply message.")
|
||||
self._state = "committed"
|
||||
self.q.put(self.value)
|
||||
|
||||
def ack(self):
|
||||
self.send(self.obj)
|
||||
def ack(self, force=False):
|
||||
self.send(self.obj, force)
|
||||
|
||||
def kill(self):
|
||||
self.send(exceptions.Kill)
|
||||
def kill(self, force=False):
|
||||
self.send(exceptions.Kill, force)
|
||||
|
||||
def send(self, msg):
|
||||
def send(self, msg, force=False):
|
||||
if self.state not in ("handled", "taken"):
|
||||
raise exceptions.ControlException(
|
||||
"Reply is currently {}, did not expect a call to .send().".format(self.state)
|
||||
"Reply is {}, did not expect a call to .send().".format(self.state)
|
||||
)
|
||||
if self.value is not NO_REPLY:
|
||||
raise exceptions.ControlException("There is already a reply for this message.")
|
||||
if self.has_message and not force:
|
||||
raise exceptions.ControlException("There is already a reply message.")
|
||||
self.value = msg
|
||||
|
||||
def __del__(self):
|
||||
if self.state != "comitted":
|
||||
if self.state != "committed":
|
||||
# This will be ignored by the interpreter, but emit a warning
|
||||
raise exceptions.ControlException("Uncomitted message: %s" % self.obj)
|
||||
raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
|
||||
|
||||
|
||||
class DummyReply(Reply):
|
||||
"""
|
||||
A reply object that is not connected to anything. In contrast to regular Reply objects,
|
||||
DummyReply objects are reset to "unhandled" after a commit so that they can be used
|
||||
DummyReply objects are reset to "unhandled" at the end of an handler so that they can be used
|
||||
multiple times. Useful when we need an object to seem like it has a channel,
|
||||
and during testing.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(DummyReply, self).__init__(None)
|
||||
|
||||
def commit(self):
|
||||
super(DummyReply, self).commit()
|
||||
def reset(self):
|
||||
if self.state != "committed":
|
||||
raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
|
||||
self._state = "unhandled"
|
||||
self.value = NO_REPLY
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
|
@ -314,7 +314,6 @@ class FlowMaster(controller.Master):
|
||||
return
|
||||
if f not in self.state.flows: # don't add again on replay
|
||||
self.state.add_flow(f)
|
||||
if not f.reply.acked:
|
||||
self.process_new_request(f)
|
||||
return f
|
||||
|
||||
@ -331,7 +330,6 @@ class FlowMaster(controller.Master):
|
||||
@controller.handler
|
||||
def response(self, f):
|
||||
self.state.update_flow(f)
|
||||
if not f.reply.acked:
|
||||
if self.client_playback:
|
||||
self.client_playback.clear(f)
|
||||
return f
|
||||
|
@ -178,7 +178,7 @@ class FlowStore(FlowList):
|
||||
|
||||
def kill_all(self, master):
|
||||
for f in self._list:
|
||||
if not f.reply.acked:
|
||||
if f.reply.state != "committed":
|
||||
f.kill(master)
|
||||
|
||||
|
||||
|
@ -155,7 +155,12 @@ class Flow(stateobject.StateObject):
|
||||
"""
|
||||
self.error = Error("Connection killed")
|
||||
self.intercepted = False
|
||||
self.reply.kill()
|
||||
# reply.state should only be handled or taken here.
|
||||
# if none of this is the case, .take() will raise an exception.
|
||||
if self.reply.state != "taken":
|
||||
self.reply.take()
|
||||
self.reply.kill(force=True)
|
||||
self.reply.commit()
|
||||
master.error(self)
|
||||
|
||||
def intercept(self, master):
|
||||
@ -166,6 +171,7 @@ class Flow(stateobject.StateObject):
|
||||
if self.intercepted:
|
||||
return
|
||||
self.intercepted = True
|
||||
self.reply.take()
|
||||
master.handle_intercept(self)
|
||||
|
||||
def accept_intercept(self, master):
|
||||
@ -176,6 +182,7 @@ class Flow(stateobject.StateObject):
|
||||
return
|
||||
self.intercepted = False
|
||||
self.reply.ack()
|
||||
self.reply.commit()
|
||||
master.handle_accept_intercept(self)
|
||||
|
||||
def match(self, f):
|
||||
|
@ -13,7 +13,7 @@ class ScriptThread(basethread.BaseThread):
|
||||
|
||||
|
||||
def concurrent(fn):
|
||||
if fn.__name__ not in controller.Events - set(["start", "configure", "tick"]):
|
||||
if fn.__name__ not in controller.Events - {"start", "configure", "tick"}:
|
||||
raise NotImplementedError(
|
||||
"Concurrent decorator not supported for '%s' method." % fn.__name__
|
||||
)
|
||||
@ -21,8 +21,10 @@ def concurrent(fn):
|
||||
def _concurrent(obj):
|
||||
def run():
|
||||
fn(obj)
|
||||
if not obj.reply.acked:
|
||||
if obj.reply.state == "taken":
|
||||
if not obj.reply.has_message:
|
||||
obj.reply.ack()
|
||||
obj.reply.commit()
|
||||
obj.reply.take()
|
||||
ScriptThread(
|
||||
"script.concurrent (%s)" % fn.__name__,
|
||||
|
@ -234,7 +234,7 @@ class AcceptFlow(RequestHandler):
|
||||
class FlowHandler(RequestHandler):
|
||||
|
||||
def delete(self, flow_id):
|
||||
if not self.flow.reply.acked:
|
||||
if self.flow.reply.state != "committed":
|
||||
self.flow.kill(self.master)
|
||||
self.state.delete_flow(self.flow)
|
||||
|
||||
@ -438,6 +438,7 @@ class Application(tornado.web.Application):
|
||||
xsrf_cookies=True,
|
||||
cookie_secret=os.urandom(256),
|
||||
debug=debug,
|
||||
autoreload=False,
|
||||
wauthenticator=wauthenticator,
|
||||
)
|
||||
super(Application, self).__init__(handlers, **settings)
|
||||
|
@ -183,7 +183,6 @@ class WebMaster(flow.FlowMaster):
|
||||
if self.state.intercept and self.state.intercept(
|
||||
f) and not f.request.is_replay:
|
||||
f.intercept(self)
|
||||
f.reply.take()
|
||||
return f
|
||||
|
||||
@controller.handler
|
||||
|
@ -12,13 +12,11 @@ class MasterTest:
|
||||
with master.handlecontext():
|
||||
func = getattr(master, handler)
|
||||
func(*message)
|
||||
if message:
|
||||
message[0].reply = controller.DummyReply()
|
||||
|
||||
def cycle(self, master, content):
|
||||
f = tutils.tflow(req=netlib.tutils.treq(content=content))
|
||||
l = proxy.Log("connect")
|
||||
l.reply = mock.MagicMock()
|
||||
l.reply = controller.DummyReply()
|
||||
master.log(l)
|
||||
self.invoke(master, "clientconnect", f.client_conn)
|
||||
self.invoke(master, "clientconnect", f.client_conn)
|
||||
|
@ -29,7 +29,7 @@ class TestConcurrent(mastertest.MasterTest):
|
||||
self.invoke(m, "request", f2)
|
||||
start = time.time()
|
||||
while time.time() - start < 5:
|
||||
if f1.reply.acked and f2.reply.acked:
|
||||
if f1.reply.state == f2.reply.state == "committed":
|
||||
return
|
||||
raise ValueError("Script never acked")
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
from test.mitmproxy import tutils
|
||||
from threading import Thread, Event
|
||||
|
||||
from mock import Mock
|
||||
@ -5,7 +6,7 @@ from mock import Mock
|
||||
from mitmproxy import controller
|
||||
from six.moves import queue
|
||||
|
||||
from mitmproxy.exceptions import Kill
|
||||
from mitmproxy.exceptions import Kill, ControlException
|
||||
from mitmproxy.proxy import DummyServer
|
||||
from netlib.tutils import raises
|
||||
|
||||
@ -55,7 +56,7 @@ class TestChannel(object):
|
||||
def test_tell(self):
|
||||
q = queue.Queue()
|
||||
channel = controller.Channel(q, Event())
|
||||
m = Mock()
|
||||
m = Mock(name="test_tell")
|
||||
channel.tell("test", m)
|
||||
assert q.get() == ("test", m)
|
||||
assert m.reply
|
||||
@ -66,12 +67,15 @@ class TestChannel(object):
|
||||
def reply():
|
||||
m, obj = q.get()
|
||||
assert m == "test"
|
||||
obj.reply.handle()
|
||||
obj.reply.send(42)
|
||||
obj.reply.take()
|
||||
obj.reply.commit()
|
||||
|
||||
Thread(target=reply).start()
|
||||
|
||||
channel = controller.Channel(q, Event())
|
||||
assert channel.ask("test", Mock()) == 42
|
||||
assert channel.ask("test", Mock(name="test_ask_simple")) == 42
|
||||
|
||||
def test_ask_shutdown(self):
|
||||
q = queue.Queue()
|
||||
@ -79,31 +83,122 @@ class TestChannel(object):
|
||||
done.set()
|
||||
channel = controller.Channel(q, done)
|
||||
with raises(Kill):
|
||||
channel.ask("test", Mock())
|
||||
|
||||
|
||||
class TestDummyReply(object):
|
||||
def test_simple(self):
|
||||
reply = controller.DummyReply()
|
||||
assert not reply.acked
|
||||
reply.ack()
|
||||
assert reply.acked
|
||||
channel.ask("test", Mock(name="test_ask_shutdown"))
|
||||
|
||||
|
||||
class TestReply(object):
|
||||
def test_simple(self):
|
||||
reply = controller.Reply(42)
|
||||
assert not reply.acked
|
||||
assert reply.state == "unhandled"
|
||||
|
||||
reply.handle()
|
||||
assert reply.state == "handled"
|
||||
|
||||
reply.send("foo")
|
||||
assert reply.acked
|
||||
assert reply.value == "foo"
|
||||
|
||||
reply.take()
|
||||
assert reply.state == "taken"
|
||||
|
||||
with tutils.raises(queue.Empty):
|
||||
reply.q.get_nowait()
|
||||
reply.commit()
|
||||
assert reply.state == "committed"
|
||||
assert reply.q.get() == "foo"
|
||||
|
||||
def test_default(self):
|
||||
reply = controller.Reply(42)
|
||||
def test_kill(self):
|
||||
reply = controller.Reply(43)
|
||||
reply.handle()
|
||||
reply.kill()
|
||||
reply.take()
|
||||
reply.commit()
|
||||
assert reply.q.get() == Kill
|
||||
|
||||
def test_ack(self):
|
||||
reply = controller.Reply(44)
|
||||
reply.handle()
|
||||
reply.ack()
|
||||
assert reply.q.get() == 42
|
||||
reply.take()
|
||||
reply.commit()
|
||||
assert reply.q.get() == 44
|
||||
|
||||
def test_reply_none(self):
|
||||
reply = controller.Reply(42)
|
||||
reply = controller.Reply(45)
|
||||
reply.handle()
|
||||
reply.send(None)
|
||||
reply.take()
|
||||
reply.commit()
|
||||
assert reply.q.get() is None
|
||||
|
||||
def test_commit_no_reply(self):
|
||||
reply = controller.Reply(46)
|
||||
reply.handle()
|
||||
reply.take()
|
||||
with tutils.raises(ControlException):
|
||||
reply.commit()
|
||||
reply.ack()
|
||||
reply.commit()
|
||||
|
||||
def test_double_send(self):
|
||||
reply = controller.Reply(47)
|
||||
reply.handle()
|
||||
reply.send(1)
|
||||
with tutils.raises(ControlException):
|
||||
reply.send(2)
|
||||
reply.take()
|
||||
reply.commit()
|
||||
|
||||
def test_state_transitions(self):
|
||||
states = {"unhandled", "handled", "taken", "committed"}
|
||||
accept = {
|
||||
"handle": {"unhandled"},
|
||||
"take": {"handled"},
|
||||
"commit": {"taken"},
|
||||
"ack": {"handled", "taken"},
|
||||
}
|
||||
for fn, ok in accept.items():
|
||||
for state in states:
|
||||
r = controller.Reply(48)
|
||||
r._state = state
|
||||
if fn == "commit":
|
||||
r.value = 49
|
||||
if state in ok:
|
||||
getattr(r, fn)()
|
||||
else:
|
||||
with tutils.raises(ControlException):
|
||||
getattr(r, fn)()
|
||||
r._state = "committed" # hide warnings on deletion
|
||||
|
||||
def test_del(self):
|
||||
reply = controller.Reply(47)
|
||||
with tutils.raises(ControlException):
|
||||
reply.__del__()
|
||||
reply.handle()
|
||||
reply.ack()
|
||||
reply.take()
|
||||
reply.commit()
|
||||
|
||||
|
||||
class TestDummyReply(object):
|
||||
def test_simple(self):
|
||||
reply = controller.DummyReply()
|
||||
for _ in range(2):
|
||||
reply.handle()
|
||||
reply.ack()
|
||||
reply.take()
|
||||
reply.commit()
|
||||
reply.reset()
|
||||
assert reply.state == "unhandled"
|
||||
|
||||
def test_reset(self):
|
||||
reply = controller.DummyReply()
|
||||
reply.handle()
|
||||
reply.ack()
|
||||
reply.take()
|
||||
reply.commit()
|
||||
reply.reset()
|
||||
assert reply.state == "unhandled"
|
||||
|
||||
def test_del(self):
|
||||
reply = controller.DummyReply()
|
||||
reply.__del__()
|
@ -375,6 +375,7 @@ class TestHTTPFlow(object):
|
||||
s = flow.State()
|
||||
fm = flow.FlowMaster(None, None, s)
|
||||
f = tutils.tflow()
|
||||
f.reply.handle()
|
||||
f.intercept(mock.Mock())
|
||||
f.kill(fm)
|
||||
for i in s.view:
|
||||
@ -385,6 +386,7 @@ class TestHTTPFlow(object):
|
||||
fm = flow.FlowMaster(None, None, s)
|
||||
|
||||
f = tutils.tflow()
|
||||
f.reply.handle()
|
||||
f.intercept(fm)
|
||||
|
||||
s.killall(fm)
|
||||
@ -393,11 +395,11 @@ class TestHTTPFlow(object):
|
||||
|
||||
def test_accept_intercept(self):
|
||||
f = tutils.tflow()
|
||||
|
||||
f.reply.handle()
|
||||
f.intercept(mock.Mock())
|
||||
assert not f.reply.acked
|
||||
assert f.reply.state == "taken"
|
||||
f.accept_intercept(mock.Mock())
|
||||
assert f.reply.acked
|
||||
assert f.reply.state == "committed"
|
||||
|
||||
def test_replace_unicode(self):
|
||||
f = tutils.tflow(resp=True)
|
||||
@ -735,7 +737,6 @@ class TestFlowMaster:
|
||||
fm.clientdisconnect(f.client_conn)
|
||||
|
||||
f.error = Error("msg")
|
||||
f.error.reply = controller.DummyReply()
|
||||
fm.error(f)
|
||||
|
||||
fm.shutdown()
|
||||
|
Loading…
Reference in New Issue
Block a user