Merge pull request #1474 from mhils/reply-fix

Improve controller.Reply semantics
This commit is contained in:
Maximilian Hils 2016-08-10 02:22:39 -07:00 committed by GitHub
commit ea2f23feff
27 changed files with 319 additions and 154 deletions

View File

@ -41,9 +41,9 @@ class Replace:
f.request.replace(rex, s) f.request.replace(rex, s)
def request(self, flow): def request(self, flow):
if not flow.reply.acked: if not flow.reply.has_message:
self.execute(flow) self.execute(flow)
def response(self, flow): def response(self, flow):
if not flow.reply.acked: if not flow.reply.has_message:
self.execute(flow) self.execute(flow)

View File

@ -31,9 +31,9 @@ class SetHeaders:
hdrs.add(header, value) hdrs.add(header, value)
def request(self, flow): def request(self, flow):
if not flow.reply.acked: if not flow.reply.has_message:
self.run(flow, flow.request.headers) self.run(flow, flow.request.headers)
def response(self, flow): def response(self, flow):
if not flow.reply.acked: if not flow.reply.has_message:
self.run(flow, flow.response.headers) self.run(flow, flow.response.headers)

View File

@ -413,7 +413,7 @@ def raw_format_flow(f, focus, extended):
def format_flow(f, focus, extended=False, hostheader=False): def format_flow(f, focus, extended=False, hostheader=False):
d = dict( d = dict(
intercepted = f.intercepted, intercepted = f.intercepted,
acked = f.reply.acked, acked = f.reply.state == "committed",
req_timestamp = f.request.timestamp_start, req_timestamp = f.request.timestamp_start,
req_is_replay = f.request.is_replay, req_is_replay = f.request.is_replay,

View File

@ -182,7 +182,7 @@ class ConnectionItem(urwid.WidgetWrap):
self.flow.accept_intercept(self.master) self.flow.accept_intercept(self.master)
signals.flowlist_change.send(self) signals.flowlist_change.send(self)
elif key == "d": elif key == "d":
if not self.flow.reply.acked: if self.flow.killable:
self.flow.kill(self.master) self.flow.kill(self.master)
self.state.delete_flow(self.flow) self.state.delete_flow(self.flow)
signals.flowlist_change.send(self) signals.flowlist_change.send(self)
@ -246,7 +246,7 @@ class ConnectionItem(urwid.WidgetWrap):
callback = self.save_flows_prompt, callback = self.save_flows_prompt,
) )
elif key == "X": elif key == "X":
if not self.flow.reply.acked: if self.flow.killable:
self.flow.kill(self.master) self.flow.kill(self.master)
elif key == "enter": elif key == "enter":
if self.flow.request: if self.flow.request:

View File

@ -8,7 +8,6 @@ import urwid
from typing import Optional, Union # noqa from typing import Optional, Union # noqa
from mitmproxy import contentviews from mitmproxy import contentviews
from mitmproxy import controller
from mitmproxy import models from mitmproxy import models
from mitmproxy import utils from mitmproxy import utils
from mitmproxy.console import common from mitmproxy.console import common
@ -148,13 +147,13 @@ class FlowView(tabs.Tabs):
signals.flow_change.connect(self.sig_flow_change) signals.flow_change.connect(self.sig_flow_change)
def tab_request(self): def tab_request(self):
if self.flow.intercepted and not self.flow.reply.acked and not self.flow.response: if self.flow.intercepted and not self.flow.response:
return "Request intercepted" return "Request intercepted"
else: else:
return "Request" return "Request"
def tab_response(self): def tab_response(self):
if self.flow.intercepted and not self.flow.reply.acked and self.flow.response: if self.flow.intercepted and self.flow.response:
return "Response intercepted" return "Response intercepted"
else: else:
return "Response" return "Response"
@ -379,7 +378,6 @@ class FlowView(tabs.Tabs):
self.flow.request.http_version, self.flow.request.http_version,
200, b"OK", Headers(), b"" 200, b"OK", Headers(), b""
) )
self.flow.response.reply = controller.DummyReply()
message = self.flow.response message = self.flow.response
self.flow.backup() self.flow.backup()
@ -538,7 +536,7 @@ class FlowView(tabs.Tabs):
else: else:
self.view_next_flow(self.flow) self.view_next_flow(self.flow)
f = self.flow f = self.flow
if not f.reply.acked: if f.killable:
f.kill(self.master) f.kill(self.master)
self.state.delete_flow(f) self.state.delete_flow(f)
elif key == "D": elif key == "D":

View File

@ -736,7 +736,6 @@ class ConsoleMaster(flow.FlowMaster):
) )
if should_intercept: if should_intercept:
f.intercept(self) f.intercept(self)
f.reply.take()
signals.flowlist_change.send(self) signals.flowlist_change.send(self)
signals.flow_change.send(self, flow = f) signals.flow_change.send(self, flow = f)

View File

@ -185,6 +185,7 @@ class Channel(object):
if g == exceptions.Kill: if g == exceptions.Kill:
raise exceptions.Kill() raise exceptions.Kill()
return g return g
m.reply._state = "committed" # suppress error message in __del__
raise exceptions.Kill() raise exceptions.Kill()
def tell(self, mtype, m): def tell(self, mtype, m):
@ -202,34 +203,47 @@ def handler(f):
if not hasattr(message, "reply"): if not hasattr(message, "reply"):
raise exceptions.ControlException("Message %s has no reply attribute" % message) raise exceptions.ControlException("Message %s has no reply attribute" % message)
# DummyReplys may be reused multiple times.
# We only clear them up on the next handler so that we can access value and
# state in the meantime.
if isinstance(message.reply, DummyReply):
message.reply.reset()
# The following ensures that inheritance with wrapped handlers in the # The following ensures that inheritance with wrapped handlers in the
# base class works. If we're the first handler, then responsibility for # base class works. If we're the first handler, then responsibility for
# acking is ours. If not, it's someone else's and we ignore it. # acking is ours. If not, it's someone else's and we ignore it.
handling = False handling = False
# We're the first handler - ack responsibility is ours # We're the first handler - ack responsibility is ours
if not message.reply.handled: if message.reply.state == "unhandled":
handling = True handling = True
message.reply.handled = True message.reply.handle()
with master.handlecontext(): with master.handlecontext():
ret = f(master, message) ret = f(master, message)
if handling: if handling:
master.addons(f.__name__, message) master.addons(f.__name__, message)
if handling and not message.reply.acked and not message.reply.taken:
message.reply.ack()
# Reset the handled flag - it's common for us to feed the same object # Reset the handled flag - it's common for us to feed the same object
# through handlers repeatedly, so we don't want this to persist across # through handlers repeatedly, so we don't want this to persist across
# calls. # calls.
if message.reply.handled: if handling and message.reply.state == "handled":
message.reply.handled = False message.reply.take()
if not message.reply.has_message:
message.reply.ack()
message.reply.commit()
# DummyReplys may be reused multiple times.
if isinstance(message.reply, DummyReply):
message.reply.mark_reset()
return ret return ret
# Mark this function as a handler wrapper # Mark this function as a handler wrapper
wrapper.__dict__["__handler"] = True wrapper.__dict__["__handler"] = True
return wrapper return wrapper
NO_REPLY = object() # special object we can distinguish from a valid "None" reply.
class Reply(object): class Reply(object):
""" """
Messages sent through a channel are decorated with a "reply" attribute. Messages sent through a channel are decorated with a "reply" attribute.
@ -238,53 +252,104 @@ class Reply(object):
""" """
def __init__(self, obj): def __init__(self, obj):
self.obj = obj self.obj = obj
self.q = queue.Queue() self.q = queue.Queue() # type: queue.Queue
# Has this message been acked?
self.acked = False
# Has the user taken responsibility for ack-ing?
self.taken = False
# Has a handler taken responsibility for ack-ing?
self.handled = False
def ack(self): self._state = "unhandled" # "unhandled" -> "handled" -> "taken" -> "committed"
self.send(self.obj) self.value = NO_REPLY # holds the reply value. May change before things are actually commited.
def kill(self): @property
self.send(exceptions.Kill) def state(self):
"""
The state the reply is currently in. A normal reply object goes sequentially through the following lifecycle:
1. unhandled: Initial State.
2. handled: The reply object has been handled by the topmost handler function.
3. taken: The reply object has been taken to be commited.
4. committed: The reply has been sent back to the requesting party.
This attribute is read-only and can only be modified by calling one of state transition functions.
"""
return self._state
@property
def has_message(self):
return self.value != NO_REPLY
@property
def done(self):
return self.state == "committed"
def handle(self):
"""
Reply are handled by controller.handlers, which may be nested. The first handler takes
responsibility and handles the reply.
"""
if self.state != "unhandled":
raise exceptions.ControlException("Reply is {}, but expected it to be unhandled.".format(self.state))
self._state = "handled"
def take(self): def take(self):
self.taken = True """
Scripts or other parties make "take" a reply out of a normal flow.
For example, intercepted flows are taken out so that the connection thread does not proceed.
"""
if self.state != "handled":
raise exceptions.ControlException("Reply is {}, but expected it to be handled.".format(self.state))
self._state = "taken"
def send(self, msg): def commit(self):
if self.acked: """
raise exceptions.ControlException("Message already acked.") Ultimately, messages are commited. This is done either automatically by the handler
self.acked = True if the message is not taken or manually by the entity which called .take().
self.q.put(msg) """
if self.state != "taken":
raise exceptions.ControlException("Reply is {}, but expected it to be taken.".format(self.state))
if not self.has_message:
raise exceptions.ControlException("There is no reply message.")
self._state = "committed"
self.q.put(self.value)
def ack(self, force=False):
self.send(self.obj, force)
def kill(self, force=False):
self.send(exceptions.Kill, force)
def send(self, msg, force=False):
if self.state not in ("handled", "taken"):
raise exceptions.ControlException(
"Reply is {}, did not expect a call to .send().".format(self.state)
)
if self.has_message and not force:
raise exceptions.ControlException("There is already a reply message.")
self.value = msg
def __del__(self): def __del__(self):
if not self.acked: if self.state != "committed":
# This will be ignored by the interpreter, but emit a warning # This will be ignored by the interpreter, but emit a warning
raise exceptions.ControlException("Un-acked message: %s" % self.obj) raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
class DummyReply(object): class DummyReply(Reply):
""" """
A reply object that does nothing. Useful when we need an object to seem A reply object that is not connected to anything. In contrast to regular Reply objects,
like it has a channel, and during testing. DummyReply objects are reset to "unhandled" at the end of an handler so that they can be used
multiple times. Useful when we need an object to seem like it has a channel,
and during testing.
""" """
def __init__(self): def __init__(self):
self.acked = False super(DummyReply, self).__init__(None)
self.taken = False self._should_reset = False
self.handled = False
def kill(self): def mark_reset(self):
self.send(None) if self.state != "committed":
raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
self._should_reset = True
def ack(self): def reset(self):
self.send(None) if self._should_reset:
self._state = "unhandled"
self.value = NO_REPLY
def take(self): def __del__(self):
self.taken = True pass
def send(self, msg):
self.acked = True

View File

@ -234,7 +234,7 @@ class FlowMaster(controller.Master):
pb = self.do_server_playback(f) pb = self.do_server_playback(f)
if not pb and self.kill_nonreplay: if not pb and self.kill_nonreplay:
self.add_log("Killed {}".format(f.request.url), "info") self.add_log("Killed {}".format(f.request.url), "info")
f.kill(self) f.reply.kill()
def replay_request(self, f, block=False): def replay_request(self, f, block=False):
""" """
@ -314,8 +314,7 @@ class FlowMaster(controller.Master):
return return
if f not in self.state.flows: # don't add again on replay if f not in self.state.flows: # don't add again on replay
self.state.add_flow(f) self.state.add_flow(f)
if not f.reply.acked: self.process_new_request(f)
self.process_new_request(f)
return f return f
@controller.handler @controller.handler
@ -331,9 +330,8 @@ class FlowMaster(controller.Master):
@controller.handler @controller.handler
def response(self, f): def response(self, f):
self.state.update_flow(f) self.state.update_flow(f)
if not f.reply.acked: if self.client_playback:
if self.client_playback: self.client_playback.clear(f)
self.client_playback.clear(f)
return f return f
def handle_intercept(self, f): def handle_intercept(self, f):

View File

@ -178,7 +178,7 @@ class FlowStore(FlowList):
def kill_all(self, master): def kill_all(self, master):
for f in self._list: for f in self._list:
if not f.reply.acked: if f.killable:
f.kill(master) f.kill(master)

View File

@ -149,13 +149,22 @@ class Flow(stateobject.StateObject):
self.set_state(self._backup) self.set_state(self._backup)
self._backup = None self._backup = None
@property
def killable(self):
return self.reply and self.reply.state in {"handled", "taken"}
def kill(self, master): def kill(self, master):
""" """
Kill this request. Kill this request.
""" """
self.error = Error("Connection killed") self.error = Error("Connection killed")
self.intercepted = False self.intercepted = False
self.reply.kill() # reply.state should only be "handled" or "taken" here.
# if none of this is the case, .take() will raise an exception.
if self.reply.state != "taken":
self.reply.take()
self.reply.kill(force=True)
self.reply.commit()
master.error(self) master.error(self)
def intercept(self, master): def intercept(self, master):
@ -166,6 +175,7 @@ class Flow(stateobject.StateObject):
if self.intercepted: if self.intercepted:
return return
self.intercepted = True self.intercepted = True
self.reply.take()
master.handle_intercept(self) master.handle_intercept(self)
def accept_intercept(self, master): def accept_intercept(self, master):
@ -176,6 +186,7 @@ class Flow(stateobject.StateObject):
return return
self.intercepted = False self.intercepted = False
self.reply.ack() self.reply.ack()
self.reply.commit()
master.handle_accept_intercept(self) master.handle_accept_intercept(self)
def match(self, f): def match(self, f):

View File

@ -13,7 +13,7 @@ class ScriptThread(basethread.BaseThread):
def concurrent(fn): def concurrent(fn):
if fn.__name__ not in controller.Events - set(["start", "configure", "tick"]): if fn.__name__ not in controller.Events - {"start", "configure", "tick"}:
raise NotImplementedError( raise NotImplementedError(
"Concurrent decorator not supported for '%s' method." % fn.__name__ "Concurrent decorator not supported for '%s' method." % fn.__name__
) )
@ -21,8 +21,10 @@ def concurrent(fn):
def _concurrent(obj): def _concurrent(obj):
def run(): def run():
fn(obj) fn(obj)
if not obj.reply.acked: if obj.reply.state == "taken":
obj.reply.ack() if not obj.reply.has_message:
obj.reply.ack()
obj.reply.commit()
obj.reply.take() obj.reply.take()
ScriptThread( ScriptThread(
"script.concurrent (%s)" % fn.__name__, "script.concurrent (%s)" % fn.__name__,

View File

@ -234,7 +234,7 @@ class AcceptFlow(RequestHandler):
class FlowHandler(RequestHandler): class FlowHandler(RequestHandler):
def delete(self, flow_id): def delete(self, flow_id):
if not self.flow.reply.acked: if self.flow.killable:
self.flow.kill(self.master) self.flow.kill(self.master)
self.state.delete_flow(self.flow) self.state.delete_flow(self.flow)
@ -438,6 +438,7 @@ class Application(tornado.web.Application):
xsrf_cookies=True, xsrf_cookies=True,
cookie_secret=os.urandom(256), cookie_secret=os.urandom(256),
debug=debug, debug=debug,
autoreload=False,
wauthenticator=wauthenticator, wauthenticator=wauthenticator,
) )
super(Application, self).__init__(handlers, **settings) super(Application, self).__init__(handlers, **settings)

View File

@ -183,7 +183,6 @@ class WebMaster(flow.FlowMaster):
if self.state.intercept and self.state.intercept( if self.state.intercept and self.state.intercept(
f) and not f.request.is_replay: f) and not f.request.is_replay:
f.intercept(self) f.intercept(self)
f.reply.take()
return f return f
@controller.handler @controller.handler

View File

@ -14,11 +14,11 @@ class TestAntiCache(mastertest.MasterTest):
m.addons.add(o, sa) m.addons.add(o, sa)
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
self.invoke(m, "request", f) m.request(f)
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
f.request.headers["if-modified-since"] = "test" f.request.headers["if-modified-since"] = "test"
f.request.headers["if-none-match"] = "test" f.request.headers["if-none-match"] = "test"
self.invoke(m, "request", f) m.request(f)
assert "if-modified-since" not in f.request.headers assert "if-modified-since" not in f.request.headers
assert "if-none-match" not in f.request.headers assert "if-none-match" not in f.request.headers

View File

@ -14,10 +14,10 @@ class TestAntiComp(mastertest.MasterTest):
m.addons.add(o, sa) m.addons.add(o, sa)
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
self.invoke(m, "request", f) m.request(f)
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
f.request.headers["Accept-Encoding"] = "foobar" f.request.headers["Accept-Encoding"] = "foobar"
self.invoke(m, "request", f) m.request(f)
assert f.request.headers["Accept-Encoding"] == "identity" assert f.request.headers["Accept-Encoding"] == "identity"

View File

@ -80,5 +80,5 @@ class TestContentView(mastertest.MasterTest):
m = mastertest.RecordingMaster(o, None, s) m = mastertest.RecordingMaster(o, None, s)
d = dumper.Dumper() d = dumper.Dumper()
m.addons.add(o, d) m.addons.add(o, d)
self.invoke(m, "response", tutils.tflow()) m.response(tutils.tflow())
assert "Content viewer failed" in m.event_log[0][1] assert "Content viewer failed" in m.event_log[0][1]

View File

@ -28,8 +28,8 @@ class TestStream(mastertest.MasterTest):
m.addons.add(o, sa) m.addons.add(o, sa)
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
self.invoke(m, "request", f) m.request(f)
self.invoke(m, "response", f) m.response(f)
m.addons.remove(sa) m.addons.remove(sa)
assert r()[0].response assert r()[0].response
@ -38,6 +38,6 @@ class TestStream(mastertest.MasterTest):
m.addons.add(o, sa) m.addons.add(o, sa)
f = tutils.tflow() f = tutils.tflow()
self.invoke(m, "request", f) m.request(f)
m.addons.remove(sa) m.addons.remove(sa)
assert not r()[1].response assert not r()[1].response

View File

@ -43,10 +43,10 @@ class TestReplace(mastertest.MasterTest):
f = tutils.tflow() f = tutils.tflow()
f.request.content = b"foo" f.request.content = b"foo"
self.invoke(m, "request", f) m.request(f)
assert f.request.content == b"bar" assert f.request.content == b"bar"
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
f.response.content = b"foo" f.response.content = b"foo"
self.invoke(m, "response", f) m.response(f)
assert f.response.content == b"bar" assert f.response.content == b"bar"

View File

@ -69,7 +69,7 @@ class TestScript(mastertest.MasterTest):
sc.ns.call_log = [] sc.ns.call_log = []
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
self.invoke(m, "request", f) m.request(f)
recf = sc.ns.call_log[0] recf = sc.ns.call_log[0]
assert recf[1] == "request" assert recf[1] == "request"
@ -102,7 +102,7 @@ class TestScript(mastertest.MasterTest):
) )
m.addons.add(o, sc) m.addons.add(o, sc)
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
self.invoke(m, "request", f) m.request(f)
assert m.event_log[0][0] == "error" assert m.event_log[0][0] == "error"
def test_duplicate_flow(self): def test_duplicate_flow(self):

View File

@ -33,12 +33,12 @@ class TestSetHeaders(mastertest.MasterTest):
) )
f = tutils.tflow() f = tutils.tflow()
f.request.headers["one"] = "xxx" f.request.headers["one"] = "xxx"
self.invoke(m, "request", f) m.request(f)
assert f.request.headers["one"] == "two" assert f.request.headers["one"] == "two"
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
f.response.headers["one"] = "xxx" f.response.headers["one"] = "xxx"
self.invoke(m, "response", f) m.response(f)
assert f.response.headers["one"] == "three" assert f.response.headers["one"] == "three"
m, sh = self.mkmaster( m, sh = self.mkmaster(
@ -50,7 +50,7 @@ class TestSetHeaders(mastertest.MasterTest):
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
f.request.headers["one"] = "xxx" f.request.headers["one"] = "xxx"
f.response.headers["one"] = "xxx" f.response.headers["one"] = "xxx"
self.invoke(m, "response", f) m.response(f)
assert f.response.headers.get_all("one") == ["two", "three"] assert f.response.headers.get_all("one") == ["two", "three"]
m, sh = self.mkmaster( m, sh = self.mkmaster(
@ -61,5 +61,5 @@ class TestSetHeaders(mastertest.MasterTest):
) )
f = tutils.tflow() f = tutils.tflow()
f.request.headers["one"] = "xxx" f.request.headers["one"] = "xxx"
self.invoke(m, "request", f) m.request(f)
assert f.request.headers.get_all("one") == ["two", "three"] assert f.request.headers.get_all("one") == ["two", "three"]

View File

@ -15,10 +15,10 @@ class TestStickyAuth(mastertest.MasterTest):
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
f.request.headers["authorization"] = "foo" f.request.headers["authorization"] = "foo"
self.invoke(m, "request", f) m.request(f)
assert "address" in sa.hosts assert "address" in sa.hosts
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
self.invoke(m, "request", f) m.request(f)
assert f.request.headers["authorization"] == "foo" assert f.request.headers["authorization"] == "foo"

View File

@ -34,23 +34,23 @@ class TestStickyCookie(mastertest.MasterTest):
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
f.response.headers["set-cookie"] = "foo=bar" f.response.headers["set-cookie"] = "foo=bar"
self.invoke(m, "request", f) m.request(f)
f.reply.acked = False f.reply.acked = False
self.invoke(m, "response", f) m.response(f)
assert sc.jar assert sc.jar
assert "cookie" not in f.request.headers assert "cookie" not in f.request.headers
f = f.copy() f = f.copy()
f.reply.acked = False f.reply.acked = False
self.invoke(m, "request", f) m.request(f)
assert f.request.headers["cookie"] == "foo=bar" assert f.request.headers["cookie"] == "foo=bar"
def _response(self, s, m, sc, cookie, host): def _response(self, s, m, sc, cookie, host):
f = tutils.tflow(req=ntutils.treq(host=host, port=80), resp=True) f = tutils.tflow(req=ntutils.treq(host=host, port=80), resp=True)
f.response.headers["Set-Cookie"] = cookie f.response.headers["Set-Cookie"] = cookie
self.invoke(m, "response", f) m.response(f)
return f return f
def test_response(self): def test_response(self):
@ -79,7 +79,7 @@ class TestStickyCookie(mastertest.MasterTest):
c2 = "othercookie=helloworld; Path=/" c2 = "othercookie=helloworld; Path=/"
f = self._response(s, m, sc, c1, "www.google.com") f = self._response(s, m, sc, c1, "www.google.com")
f.response.headers["Set-Cookie"] = c2 f.response.headers["Set-Cookie"] = c2
self.invoke(m, "response", f) m.response(f)
googlekey = list(sc.jar.keys())[0] googlekey = list(sc.jar.keys())[0]
assert len(sc.jar[googlekey].keys()) == 2 assert len(sc.jar[googlekey].keys()) == 2
@ -96,7 +96,7 @@ class TestStickyCookie(mastertest.MasterTest):
] ]
for c in cs: for c in cs:
f.response.headers["Set-Cookie"] = c f.response.headers["Set-Cookie"] = c
self.invoke(m, "response", f) m.response(f)
googlekey = list(sc.jar.keys())[0] googlekey = list(sc.jar.keys())[0]
assert len(sc.jar[googlekey].keys()) == len(cs) assert len(sc.jar[googlekey].keys()) == len(cs)
@ -108,7 +108,7 @@ class TestStickyCookie(mastertest.MasterTest):
c2 = "somecookie=newvalue; Path=/" c2 = "somecookie=newvalue; Path=/"
f = self._response(s, m, sc, c1, "www.google.com") f = self._response(s, m, sc, c1, "www.google.com")
f.response.headers["Set-Cookie"] = c2 f.response.headers["Set-Cookie"] = c2
self.invoke(m, "response", f) m.response(f)
googlekey = list(sc.jar.keys())[0] googlekey = list(sc.jar.keys())[0]
assert len(sc.jar[googlekey].keys()) == 1 assert len(sc.jar[googlekey].keys()) == 1
assert list(sc.jar[googlekey]["somecookie"].items())[0][1] == "newvalue" assert list(sc.jar[googlekey]["somecookie"].items())[0][1] == "newvalue"
@ -120,7 +120,7 @@ class TestStickyCookie(mastertest.MasterTest):
# by setting the expire time in the past # by setting the expire time in the past
f = self._response(s, m, sc, "duffer=zafar; Path=/", "www.google.com") f = self._response(s, m, sc, "duffer=zafar; Path=/", "www.google.com")
f.response.headers["Set-Cookie"] = "duffer=; Expires=Thu, 01-Jan-1970 00:00:00 GMT" f.response.headers["Set-Cookie"] = "duffer=; Expires=Thu, 01-Jan-1970 00:00:00 GMT"
self.invoke(m, "response", f) m.response(f)
assert not sc.jar.keys() assert not sc.jar.keys()
def test_request(self): def test_request(self):
@ -128,5 +128,5 @@ class TestStickyCookie(mastertest.MasterTest):
f = self._response(s, m, sc, "SSID=mooo", "www.google.com") f = self._response(s, m, sc, "SSID=mooo", "www.google.com")
assert "cookie" not in f.request.headers assert "cookie" not in f.request.headers
self.invoke(m, "request", f) m.request(f)
assert "cookie" in f.request.headers assert "cookie" in f.request.headers

View File

@ -1,5 +1,3 @@
import mock
from . import tutils from . import tutils
import netlib.tutils import netlib.tutils
@ -8,26 +6,19 @@ from mitmproxy import flow, proxy, models, controller
class MasterTest: class MasterTest:
def invoke(self, master, handler, *message):
with master.handlecontext():
func = getattr(master, handler)
func(*message)
if message:
message[0].reply = controller.DummyReply()
def cycle(self, master, content): def cycle(self, master, content):
f = tutils.tflow(req=netlib.tutils.treq(content=content)) f = tutils.tflow(req=netlib.tutils.treq(content=content))
l = proxy.Log("connect") l = proxy.Log("connect")
l.reply = mock.MagicMock() l.reply = controller.DummyReply()
master.log(l) master.log(l)
self.invoke(master, "clientconnect", f.client_conn) master.clientconnect(f.client_conn)
self.invoke(master, "clientconnect", f.client_conn) master.serverconnect(f.server_conn)
self.invoke(master, "serverconnect", f.server_conn) master.request(f)
self.invoke(master, "request", f)
if not f.error: if not f.error:
f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content)) f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content))
self.invoke(master, "response", f) master.response(f)
self.invoke(master, "clientdisconnect", f) master.clientdisconnect(f)
return f return f
def dummy_cycle(self, master, n, content): def dummy_cycle(self, master, n, content):

View File

@ -25,11 +25,11 @@ class TestConcurrent(mastertest.MasterTest):
) )
m.addons.add(m.options, sc) m.addons.add(m.options, sc)
f1, f2 = tutils.tflow(), tutils.tflow() f1, f2 = tutils.tflow(), tutils.tflow()
self.invoke(m, "request", f1) m.request(f1)
self.invoke(m, "request", f2) m.request(f2)
start = time.time() start = time.time()
while time.time() - start < 5: while time.time() - start < 5:
if f1.reply.acked and f2.reply.acked: if f1.reply.state == f2.reply.state == "committed":
return return
raise ValueError("Script never acked") raise ValueError("Script never acked")

View File

@ -1,3 +1,4 @@
from test.mitmproxy import tutils
from threading import Thread, Event from threading import Thread, Event
from mock import Mock from mock import Mock
@ -5,7 +6,7 @@ from mock import Mock
from mitmproxy import controller from mitmproxy import controller
from six.moves import queue from six.moves import queue
from mitmproxy.exceptions import Kill from mitmproxy.exceptions import Kill, ControlException
from mitmproxy.proxy import DummyServer from mitmproxy.proxy import DummyServer
from netlib.tutils import raises from netlib.tutils import raises
@ -55,7 +56,7 @@ class TestChannel(object):
def test_tell(self): def test_tell(self):
q = queue.Queue() q = queue.Queue()
channel = controller.Channel(q, Event()) channel = controller.Channel(q, Event())
m = Mock() m = Mock(name="test_tell")
channel.tell("test", m) channel.tell("test", m)
assert q.get() == ("test", m) assert q.get() == ("test", m)
assert m.reply assert m.reply
@ -66,12 +67,15 @@ class TestChannel(object):
def reply(): def reply():
m, obj = q.get() m, obj = q.get()
assert m == "test" assert m == "test"
obj.reply.handle()
obj.reply.send(42) obj.reply.send(42)
obj.reply.take()
obj.reply.commit()
Thread(target=reply).start() Thread(target=reply).start()
channel = controller.Channel(q, Event()) channel = controller.Channel(q, Event())
assert channel.ask("test", Mock()) == 42 assert channel.ask("test", Mock(name="test_ask_simple")) == 42
def test_ask_shutdown(self): def test_ask_shutdown(self):
q = queue.Queue() q = queue.Queue()
@ -79,31 +83,125 @@ class TestChannel(object):
done.set() done.set()
channel = controller.Channel(q, done) channel = controller.Channel(q, done)
with raises(Kill): with raises(Kill):
channel.ask("test", Mock()) channel.ask("test", Mock(name="test_ask_shutdown"))
class TestDummyReply(object):
def test_simple(self):
reply = controller.DummyReply()
assert not reply.acked
reply.ack()
assert reply.acked
class TestReply(object): class TestReply(object):
def test_simple(self): def test_simple(self):
reply = controller.Reply(42) reply = controller.Reply(42)
assert not reply.acked assert reply.state == "unhandled"
reply.handle()
assert reply.state == "handled"
reply.send("foo") reply.send("foo")
assert reply.acked assert reply.value == "foo"
reply.take()
assert reply.state == "taken"
with tutils.raises(queue.Empty):
reply.q.get_nowait()
reply.commit()
assert reply.state == "committed"
assert reply.q.get() == "foo" assert reply.q.get() == "foo"
def test_default(self): def test_kill(self):
reply = controller.Reply(42) reply = controller.Reply(43)
reply.handle()
reply.kill()
reply.take()
reply.commit()
assert reply.q.get() == Kill
def test_ack(self):
reply = controller.Reply(44)
reply.handle()
reply.ack() reply.ack()
assert reply.q.get() == 42 reply.take()
reply.commit()
assert reply.q.get() == 44
def test_reply_none(self): def test_reply_none(self):
reply = controller.Reply(42) reply = controller.Reply(45)
reply.handle()
reply.send(None) reply.send(None)
reply.take()
reply.commit()
assert reply.q.get() is None assert reply.q.get() is None
def test_commit_no_reply(self):
reply = controller.Reply(46)
reply.handle()
reply.take()
with tutils.raises(ControlException):
reply.commit()
reply.ack()
reply.commit()
def test_double_send(self):
reply = controller.Reply(47)
reply.handle()
reply.send(1)
with tutils.raises(ControlException):
reply.send(2)
reply.take()
reply.commit()
def test_state_transitions(self):
states = {"unhandled", "handled", "taken", "committed"}
accept = {
"handle": {"unhandled"},
"take": {"handled"},
"commit": {"taken"},
"ack": {"handled", "taken"},
}
for fn, ok in accept.items():
for state in states:
r = controller.Reply(48)
r._state = state
if fn == "commit":
r.value = 49
if state in ok:
getattr(r, fn)()
else:
with tutils.raises(ControlException):
getattr(r, fn)()
r._state = "committed" # hide warnings on deletion
def test_del(self):
reply = controller.Reply(47)
with tutils.raises(ControlException):
reply.__del__()
reply.handle()
reply.ack()
reply.take()
reply.commit()
class TestDummyReply(object):
def test_simple(self):
reply = controller.DummyReply()
for _ in range(2):
reply.handle()
reply.ack()
reply.take()
reply.commit()
reply.mark_reset()
reply.reset()
assert reply.state == "unhandled"
def test_reset(self):
reply = controller.DummyReply()
reply.handle()
reply.ack()
reply.take()
reply.commit()
reply.mark_reset()
assert reply.state == "committed"
reply.reset()
assert reply.state == "unhandled"
def test_del(self):
reply = controller.DummyReply()
reply.__del__()

View File

@ -39,7 +39,7 @@ class TestScripts(mastertest.MasterTest):
def test_add_header(self): def test_add_header(self):
m, _ = tscript("add_header.py") m, _ = tscript("add_header.py")
f = tutils.tflow(resp=netutils.tresp()) f = tutils.tflow(resp=netutils.tresp())
self.invoke(m, "response", f) m.response(f)
assert f.response.headers["newheader"] == "foo" assert f.response.headers["newheader"] == "foo"
def test_custom_contentviews(self): def test_custom_contentviews(self):
@ -54,9 +54,9 @@ class TestScripts(mastertest.MasterTest):
tscript("iframe_injector.py") tscript("iframe_injector.py")
m, sc = tscript("iframe_injector.py", "http://example.org/evil_iframe") m, sc = tscript("iframe_injector.py", "http://example.org/evil_iframe")
flow = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>")) f = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>"))
self.invoke(m, "response", flow) m.response(f)
content = flow.response.content content = f.response.content
assert b'iframe' in content and b'evil_iframe' in content assert b'iframe' in content and b'evil_iframe' in content
def test_modify_form(self): def test_modify_form(self):
@ -64,23 +64,23 @@ class TestScripts(mastertest.MasterTest):
form_header = Headers(content_type="application/x-www-form-urlencoded") form_header = Headers(content_type="application/x-www-form-urlencoded")
f = tutils.tflow(req=netutils.treq(headers=form_header)) f = tutils.tflow(req=netutils.treq(headers=form_header))
self.invoke(m, "request", f) m.request(f)
assert f.request.urlencoded_form[b"mitmproxy"] == b"rocks" assert f.request.urlencoded_form[b"mitmproxy"] == b"rocks"
f.request.headers["content-type"] = "" f.request.headers["content-type"] = ""
self.invoke(m, "request", f) m.request(f)
assert list(f.request.urlencoded_form.items()) == [(b"foo", b"bar")] assert list(f.request.urlencoded_form.items()) == [(b"foo", b"bar")]
def test_modify_querystring(self): def test_modify_querystring(self):
m, sc = tscript("modify_querystring.py") m, sc = tscript("modify_querystring.py")
f = tutils.tflow(req=netutils.treq(path="/search?q=term")) f = tutils.tflow(req=netutils.treq(path="/search?q=term"))
self.invoke(m, "request", f) m.request(f)
assert f.request.query["mitmproxy"] == "rocks" assert f.request.query["mitmproxy"] == "rocks"
f.request.path = "/" f.request.path = "/"
self.invoke(m, "request", f) m.request(f)
assert f.request.query["mitmproxy"] == "rocks" assert f.request.query["mitmproxy"] == "rocks"
def test_modify_response_body(self): def test_modify_response_body(self):
@ -89,13 +89,13 @@ class TestScripts(mastertest.MasterTest):
m, sc = tscript("modify_response_body.py", "mitmproxy rocks") m, sc = tscript("modify_response_body.py", "mitmproxy rocks")
f = tutils.tflow(resp=netutils.tresp(content=b"I <3 mitmproxy")) f = tutils.tflow(resp=netutils.tresp(content=b"I <3 mitmproxy"))
self.invoke(m, "response", f) m.response(f)
assert f.response.content == b"I <3 rocks" assert f.response.content == b"I <3 rocks"
def test_redirect_requests(self): def test_redirect_requests(self):
m, sc = tscript("redirect_requests.py") m, sc = tscript("redirect_requests.py")
f = tutils.tflow(req=netutils.treq(host="example.org")) f = tutils.tflow(req=netutils.treq(host="example.org"))
self.invoke(m, "request", f) m.request(f)
assert f.request.host == "mitmproxy.org" assert f.request.host == "mitmproxy.org"
def test_har_extractor(self): def test_har_extractor(self):
@ -119,7 +119,7 @@ class TestScripts(mastertest.MasterTest):
req=netutils.treq(**times), req=netutils.treq(**times),
resp=netutils.tresp(**times) resp=netutils.tresp(**times)
) )
self.invoke(m, "response", f) m.response(f)
m.addons.remove(sc) m.addons.remove(sc)
with open(path, "rb") as f: with open(path, "rb") as f:

View File

@ -3,9 +3,9 @@ import io
import netlib.utils import netlib.utils
from netlib.http import Headers from netlib.http import Headers
from mitmproxy import filt, controller, flow, options from mitmproxy import filt, flow, options
from mitmproxy.contrib import tnetstring from mitmproxy.contrib import tnetstring
from mitmproxy.exceptions import FlowReadException from mitmproxy.exceptions import FlowReadException, Kill
from mitmproxy.models import Error from mitmproxy.models import Error
from mitmproxy.models import Flow from mitmproxy.models import Flow
from mitmproxy.models import HTTPFlow from mitmproxy.models import HTTPFlow
@ -372,19 +372,23 @@ class TestHTTPFlow(object):
assert f.get_state() == f2.get_state() assert f.get_state() == f2.get_state()
def test_kill(self): def test_kill(self):
s = flow.State() fm = mock.Mock()
fm = flow.FlowMaster(None, None, s)
f = tutils.tflow() f = tutils.tflow()
f.intercept(mock.Mock()) f.reply.handle()
f.intercept(fm)
assert fm.handle_intercept.called
assert f.killable
f.kill(fm) f.kill(fm)
for i in s.view: assert not f.killable
assert "killed" in str(i.error) assert fm.error.called
assert f.reply.value == Kill
def test_killall(self): def test_killall(self):
s = flow.State() s = flow.State()
fm = flow.FlowMaster(None, None, s) fm = flow.FlowMaster(None, None, s)
f = tutils.tflow() f = tutils.tflow()
f.reply.handle()
f.intercept(fm) f.intercept(fm)
s.killall(fm) s.killall(fm)
@ -393,11 +397,11 @@ class TestHTTPFlow(object):
def test_accept_intercept(self): def test_accept_intercept(self):
f = tutils.tflow() f = tutils.tflow()
f.reply.handle()
f.intercept(mock.Mock()) f.intercept(mock.Mock())
assert not f.reply.acked assert f.reply.state == "taken"
f.accept_intercept(mock.Mock()) f.accept_intercept(mock.Mock())
assert f.reply.acked assert f.reply.state == "committed"
def test_replace_unicode(self): def test_replace_unicode(self):
f = tutils.tflow(resp=True) f = tutils.tflow(resp=True)
@ -735,7 +739,6 @@ class TestFlowMaster:
fm.clientdisconnect(f.client_conn) fm.clientdisconnect(f.client_conn)
f.error = Error("msg") f.error = Error("msg")
f.error.reply = controller.DummyReply()
fm.error(f) fm.error(f)
fm.shutdown() fm.shutdown()
@ -834,8 +837,8 @@ class TestFlowMaster:
f = tutils.tflow() f = tutils.tflow()
f.request.host = "nonexistent" f.request.host = "nonexistent"
fm.process_new_request(f) fm.request(f)
assert "killed" in f.error.msg assert f.reply.value == Kill
class TestRequest: class TestRequest: