clean up code, improve DummyReply

This commit is contained in:
Maximilian Hils 2016-08-09 22:29:07 -07:00
parent 818840f553
commit 5a22496ee8
21 changed files with 90 additions and 74 deletions

View File

@ -182,7 +182,7 @@ class ConnectionItem(urwid.WidgetWrap):
self.flow.accept_intercept(self.master)
signals.flowlist_change.send(self)
elif key == "d":
if self.flow.reply and self.flow.reply.state != "committed":
if self.flow.killable:
self.flow.kill(self.master)
self.state.delete_flow(self.flow)
signals.flowlist_change.send(self)
@ -246,7 +246,7 @@ class ConnectionItem(urwid.WidgetWrap):
callback = self.save_flows_prompt,
)
elif key == "X":
if self.flow.reply and self.flow.reply.state != "committed":
if self.flow.killable:
self.flow.kill(self.master)
elif key == "enter":
if self.flow.request:

View File

@ -8,7 +8,6 @@ import urwid
from typing import Optional, Union # noqa
from mitmproxy import contentviews
from mitmproxy import controller
from mitmproxy import models
from mitmproxy import utils
from mitmproxy.console import common
@ -537,7 +536,7 @@ class FlowView(tabs.Tabs):
else:
self.view_next_flow(self.flow)
f = self.flow
if not f.reply.acked:
if f.killable:
f.kill(self.master)
self.state.delete_flow(f)
elif key == "D":

View File

@ -203,6 +203,12 @@ def handler(f):
if not hasattr(message, "reply"):
raise exceptions.ControlException("Message %s has no reply attribute" % message)
# DummyReplys may be reused multiple times.
# We only clear them up on the next handler so that we can access value and
# state in the meantime.
if isinstance(message.reply, DummyReply):
message.reply.reset()
# The following ensures that inheritance with wrapped handlers in the
# base class works. If we're the first handler, then responsibility for
# acking is ours. If not, it's someone else's and we ignore it.
@ -228,7 +234,7 @@ def handler(f):
# DummyReplys may be reused multiple times.
if isinstance(message.reply, DummyReply):
message.reply.reset()
message.reply.mark_reset()
return ret
# Mark this function as a handler wrapper
wrapper.__dict__["__handler"] = True
@ -269,6 +275,10 @@ class Reply(object):
def has_message(self):
return self.value != NO_REPLY
@property
def done(self):
return self.state == "committed"
def handle(self):
"""
Reply are handled by controller.handlers, which may be nested. The first handler takes
@ -329,12 +339,17 @@ class DummyReply(Reply):
"""
def __init__(self):
super(DummyReply, self).__init__(None)
self._should_reset = False
def reset(self):
def mark_reset(self):
if self.state != "committed":
raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
self._state = "unhandled"
self.value = NO_REPLY
self._should_reset = True
def reset(self):
if self._should_reset:
self._state = "unhandled"
self.value = NO_REPLY
def __del__(self):
pass

View File

@ -234,7 +234,7 @@ class FlowMaster(controller.Master):
pb = self.do_server_playback(f)
if not pb and self.kill_nonreplay:
self.add_log("Killed {}".format(f.request.url), "info")
f.kill(self)
f.reply.kill()
def replay_request(self, f, block=False):
"""

View File

@ -178,7 +178,7 @@ class FlowStore(FlowList):
def kill_all(self, master):
for f in self._list:
if f.reply.state != "committed":
if f.killable:
f.kill(master)

View File

@ -149,13 +149,17 @@ class Flow(stateobject.StateObject):
self.set_state(self._backup)
self._backup = None
@property
def killable(self):
return self.reply and self.reply.state in {"handled", "taken"}
def kill(self, master):
"""
Kill this request.
"""
self.error = Error("Connection killed")
self.intercepted = False
# reply.state should only be handled or taken here.
# reply.state should only be "handled" or "taken" here.
# if none of this is the case, .take() will raise an exception.
if self.reply.state != "taken":
self.reply.take()

View File

@ -234,7 +234,7 @@ class AcceptFlow(RequestHandler):
class FlowHandler(RequestHandler):
def delete(self, flow_id):
if self.flow.reply.state != "committed":
if self.flow.killable:
self.flow.kill(self.master)
self.state.delete_flow(self.flow)

View File

@ -14,11 +14,11 @@ class TestAntiCache(mastertest.MasterTest):
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
m.request(f)
f = tutils.tflow(resp=True)
f.request.headers["if-modified-since"] = "test"
f.request.headers["if-none-match"] = "test"
self.invoke(m, "request", f)
m.request(f)
assert "if-modified-since" not in f.request.headers
assert "if-none-match" not in f.request.headers

View File

@ -14,10 +14,10 @@ class TestAntiComp(mastertest.MasterTest):
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
m.request(f)
f = tutils.tflow(resp=True)
f.request.headers["Accept-Encoding"] = "foobar"
self.invoke(m, "request", f)
m.request(f)
assert f.request.headers["Accept-Encoding"] == "identity"

View File

@ -80,5 +80,5 @@ class TestContentView(mastertest.MasterTest):
m = mastertest.RecordingMaster(o, None, s)
d = dumper.Dumper()
m.addons.add(o, d)
self.invoke(m, "response", tutils.tflow())
m.response(tutils.tflow())
assert "Content viewer failed" in m.event_log[0][1]

View File

@ -28,8 +28,8 @@ class TestStream(mastertest.MasterTest):
m.addons.add(o, sa)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
self.invoke(m, "response", f)
m.request(f)
m.response(f)
m.addons.remove(sa)
assert r()[0].response
@ -38,6 +38,6 @@ class TestStream(mastertest.MasterTest):
m.addons.add(o, sa)
f = tutils.tflow()
self.invoke(m, "request", f)
m.request(f)
m.addons.remove(sa)
assert not r()[1].response

View File

@ -43,10 +43,10 @@ class TestReplace(mastertest.MasterTest):
f = tutils.tflow()
f.request.content = b"foo"
self.invoke(m, "request", f)
m.request(f)
assert f.request.content == b"bar"
f = tutils.tflow(resp=True)
f.response.content = b"foo"
self.invoke(m, "response", f)
m.response(f)
assert f.response.content == b"bar"

View File

@ -69,7 +69,7 @@ class TestScript(mastertest.MasterTest):
sc.ns.call_log = []
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
m.request(f)
recf = sc.ns.call_log[0]
assert recf[1] == "request"
@ -102,7 +102,7 @@ class TestScript(mastertest.MasterTest):
)
m.addons.add(o, sc)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
m.request(f)
assert m.event_log[0][0] == "error"
def test_duplicate_flow(self):

View File

@ -33,12 +33,12 @@ class TestSetHeaders(mastertest.MasterTest):
)
f = tutils.tflow()
f.request.headers["one"] = "xxx"
self.invoke(m, "request", f)
m.request(f)
assert f.request.headers["one"] == "two"
f = tutils.tflow(resp=True)
f.response.headers["one"] = "xxx"
self.invoke(m, "response", f)
m.response(f)
assert f.response.headers["one"] == "three"
m, sh = self.mkmaster(
@ -50,7 +50,7 @@ class TestSetHeaders(mastertest.MasterTest):
f = tutils.tflow(resp=True)
f.request.headers["one"] = "xxx"
f.response.headers["one"] = "xxx"
self.invoke(m, "response", f)
m.response(f)
assert f.response.headers.get_all("one") == ["two", "three"]
m, sh = self.mkmaster(
@ -61,5 +61,5 @@ class TestSetHeaders(mastertest.MasterTest):
)
f = tutils.tflow()
f.request.headers["one"] = "xxx"
self.invoke(m, "request", f)
m.request(f)
assert f.request.headers.get_all("one") == ["two", "three"]

View File

@ -15,10 +15,10 @@ class TestStickyAuth(mastertest.MasterTest):
f = tutils.tflow(resp=True)
f.request.headers["authorization"] = "foo"
self.invoke(m, "request", f)
m.request(f)
assert "address" in sa.hosts
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
m.request(f)
assert f.request.headers["authorization"] == "foo"

View File

@ -34,23 +34,23 @@ class TestStickyCookie(mastertest.MasterTest):
f = tutils.tflow(resp=True)
f.response.headers["set-cookie"] = "foo=bar"
self.invoke(m, "request", f)
m.request(f)
f.reply.acked = False
self.invoke(m, "response", f)
m.response(f)
assert sc.jar
assert "cookie" not in f.request.headers
f = f.copy()
f.reply.acked = False
self.invoke(m, "request", f)
m.request(f)
assert f.request.headers["cookie"] == "foo=bar"
def _response(self, s, m, sc, cookie, host):
f = tutils.tflow(req=ntutils.treq(host=host, port=80), resp=True)
f.response.headers["Set-Cookie"] = cookie
self.invoke(m, "response", f)
m.response(f)
return f
def test_response(self):
@ -79,7 +79,7 @@ class TestStickyCookie(mastertest.MasterTest):
c2 = "othercookie=helloworld; Path=/"
f = self._response(s, m, sc, c1, "www.google.com")
f.response.headers["Set-Cookie"] = c2
self.invoke(m, "response", f)
m.response(f)
googlekey = list(sc.jar.keys())[0]
assert len(sc.jar[googlekey].keys()) == 2
@ -96,7 +96,7 @@ class TestStickyCookie(mastertest.MasterTest):
]
for c in cs:
f.response.headers["Set-Cookie"] = c
self.invoke(m, "response", f)
m.response(f)
googlekey = list(sc.jar.keys())[0]
assert len(sc.jar[googlekey].keys()) == len(cs)
@ -108,7 +108,7 @@ class TestStickyCookie(mastertest.MasterTest):
c2 = "somecookie=newvalue; Path=/"
f = self._response(s, m, sc, c1, "www.google.com")
f.response.headers["Set-Cookie"] = c2
self.invoke(m, "response", f)
m.response(f)
googlekey = list(sc.jar.keys())[0]
assert len(sc.jar[googlekey].keys()) == 1
assert list(sc.jar[googlekey]["somecookie"].items())[0][1] == "newvalue"
@ -120,7 +120,7 @@ class TestStickyCookie(mastertest.MasterTest):
# by setting the expire time in the past
f = self._response(s, m, sc, "duffer=zafar; Path=/", "www.google.com")
f.response.headers["Set-Cookie"] = "duffer=; Expires=Thu, 01-Jan-1970 00:00:00 GMT"
self.invoke(m, "response", f)
m.response(f)
assert not sc.jar.keys()
def test_request(self):
@ -128,5 +128,5 @@ class TestStickyCookie(mastertest.MasterTest):
f = self._response(s, m, sc, "SSID=mooo", "www.google.com")
assert "cookie" not in f.request.headers
self.invoke(m, "request", f)
m.request(f)
assert "cookie" in f.request.headers

View File

@ -1,5 +1,3 @@
import mock
from . import tutils
import netlib.tutils
@ -8,24 +6,19 @@ from mitmproxy import flow, proxy, models, controller
class MasterTest:
def invoke(self, master, handler, *message):
with master.handlecontext():
func = getattr(master, handler)
func(*message)
def cycle(self, master, content):
f = tutils.tflow(req=netlib.tutils.treq(content=content))
l = proxy.Log("connect")
l.reply = controller.DummyReply()
master.log(l)
self.invoke(master, "clientconnect", f.client_conn)
self.invoke(master, "clientconnect", f.client_conn)
self.invoke(master, "serverconnect", f.server_conn)
self.invoke(master, "request", f)
master.clientconnect(f.client_conn)
master.serverconnect(f.server_conn)
master.request(f)
if not f.error:
f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content))
self.invoke(master, "response", f)
self.invoke(master, "clientdisconnect", f)
master.response(f)
master.clientdisconnect(f)
return f
def dummy_cycle(self, master, n, content):

View File

@ -25,8 +25,8 @@ class TestConcurrent(mastertest.MasterTest):
)
m.addons.add(m.options, sc)
f1, f2 = tutils.tflow(), tutils.tflow()
self.invoke(m, "request", f1)
self.invoke(m, "request", f2)
m.request(f1)
m.request(f2)
start = time.time()
while time.time() - start < 5:
if f1.reply.state == f2.reply.state == "committed":

View File

@ -187,6 +187,7 @@ class TestDummyReply(object):
reply.ack()
reply.take()
reply.commit()
reply.mark_reset()
reply.reset()
assert reply.state == "unhandled"
@ -196,9 +197,11 @@ class TestDummyReply(object):
reply.ack()
reply.take()
reply.commit()
reply.mark_reset()
assert reply.state == "committed"
reply.reset()
assert reply.state == "unhandled"
def test_del(self):
reply = controller.DummyReply()
reply.__del__()
reply.__del__()

View File

@ -39,7 +39,7 @@ class TestScripts(mastertest.MasterTest):
def test_add_header(self):
m, _ = tscript("add_header.py")
f = tutils.tflow(resp=netutils.tresp())
self.invoke(m, "response", f)
m.response(f)
assert f.response.headers["newheader"] == "foo"
def test_custom_contentviews(self):
@ -54,9 +54,9 @@ class TestScripts(mastertest.MasterTest):
tscript("iframe_injector.py")
m, sc = tscript("iframe_injector.py", "http://example.org/evil_iframe")
flow = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>"))
self.invoke(m, "response", flow)
content = flow.response.content
f = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>"))
m.response(f)
content = f.response.content
assert b'iframe' in content and b'evil_iframe' in content
def test_modify_form(self):
@ -64,23 +64,23 @@ class TestScripts(mastertest.MasterTest):
form_header = Headers(content_type="application/x-www-form-urlencoded")
f = tutils.tflow(req=netutils.treq(headers=form_header))
self.invoke(m, "request", f)
m.request(f)
assert f.request.urlencoded_form[b"mitmproxy"] == b"rocks"
f.request.headers["content-type"] = ""
self.invoke(m, "request", f)
m.request(f)
assert list(f.request.urlencoded_form.items()) == [(b"foo", b"bar")]
def test_modify_querystring(self):
m, sc = tscript("modify_querystring.py")
f = tutils.tflow(req=netutils.treq(path="/search?q=term"))
self.invoke(m, "request", f)
m.request(f)
assert f.request.query["mitmproxy"] == "rocks"
f.request.path = "/"
self.invoke(m, "request", f)
m.request(f)
assert f.request.query["mitmproxy"] == "rocks"
def test_modify_response_body(self):
@ -89,13 +89,13 @@ class TestScripts(mastertest.MasterTest):
m, sc = tscript("modify_response_body.py", "mitmproxy rocks")
f = tutils.tflow(resp=netutils.tresp(content=b"I <3 mitmproxy"))
self.invoke(m, "response", f)
m.response(f)
assert f.response.content == b"I <3 rocks"
def test_redirect_requests(self):
m, sc = tscript("redirect_requests.py")
f = tutils.tflow(req=netutils.treq(host="example.org"))
self.invoke(m, "request", f)
m.request(f)
assert f.request.host == "mitmproxy.org"
def test_har_extractor(self):
@ -119,7 +119,7 @@ class TestScripts(mastertest.MasterTest):
req=netutils.treq(**times),
resp=netutils.tresp(**times)
)
self.invoke(m, "response", f)
m.response(f)
m.addons.remove(sc)
with open(path, "rb") as f:

View File

@ -3,9 +3,9 @@ import io
import netlib.utils
from netlib.http import Headers
from mitmproxy import filt, controller, flow, options
from mitmproxy import filt, flow, options
from mitmproxy.contrib import tnetstring
from mitmproxy.exceptions import FlowReadException
from mitmproxy.exceptions import FlowReadException, Kill
from mitmproxy.models import Error
from mitmproxy.models import Flow
from mitmproxy.models import HTTPFlow
@ -372,14 +372,16 @@ class TestHTTPFlow(object):
assert f.get_state() == f2.get_state()
def test_kill(self):
s = flow.State()
fm = flow.FlowMaster(None, None, s)
fm = mock.Mock()
f = tutils.tflow()
f.reply.handle()
f.intercept(mock.Mock())
f.intercept(fm)
assert fm.handle_intercept.called
assert f.killable
f.kill(fm)
for i in s.view:
assert "killed" in str(i.error)
assert not f.killable
assert fm.error.called
assert f.reply.value == Kill
def test_killall(self):
s = flow.State()
@ -835,8 +837,8 @@ class TestFlowMaster:
f = tutils.tflow()
f.request.host = "nonexistent"
fm.process_new_request(f)
assert "killed" in f.error.msg
fm.request(f)
assert f.reply.value == Kill
class TestRequest: