Mandate that all handlers must be wrapped, make tests pass

mitmproxy, mitmdump and mitmweb masters still to be done
This commit is contained in:
Aldo Cortesi 2016-05-26 13:14:57 +12:00
parent 23efee9813
commit 08f2a0524e
5 changed files with 44 additions and 34 deletions

View File

@ -42,12 +42,12 @@ class Master(object):
while True:
mtype, obj = self.event_queue.get(timeout=timeout)
handle_func = getattr(self, "handle_" + mtype)
# if not handle_func.func_dict.get("handler"):
# raise ControlError(
# "Handler function %s is not decorated with controller.handler"%(
# handle_func
# )
# )
if not handle_func.func_dict.get("handler"):
raise ControlError(
"Handler function %s is not decorated with controller.handler"%(
handle_func
)
)
handle_func(obj)
self.event_queue.task_done()
changed = True
@ -160,10 +160,11 @@ NO_REPLY = object()
def handler(f):
@functools.wraps(f)
def wrapper(obj, message, *args, **kwargs):
def wrapper(*args, **kwargs):
message = args[-1]
if not hasattr(message, "reply"):
raise ControlError("Message %s has no reply attribute"%message)
ret = f(obj, message, *args, **kwargs)
ret = f(*args, **kwargs)
if not message.reply.acked and not message.reply.taken:
message.reply()
return ret

View File

@ -6,10 +6,11 @@ import itertools
from netlib import tcp
from netlib.utils import bytes_to_escaped_str, pretty_size
from . import flow, filt, contentviews
from . import flow, filt, contentviews, controller
from .exceptions import ContentViewException, FlowReadException, ScriptException
class DumpError(Exception):
pass

View File

@ -1017,6 +1017,7 @@ class FlowMaster(controller.ServerMaster):
self.client_playback.clear(f)
return f
@controller.handler
def handle_request(self, f):
if f.live:
app = self.apps.get(f.request)
@ -1039,6 +1040,7 @@ class FlowMaster(controller.ServerMaster):
self.run_script_hook("request", f)
return f
@controller.handler
def handle_responseheaders(self, f):
try:
if self.stream_large_bodies:
@ -1046,12 +1048,10 @@ class FlowMaster(controller.ServerMaster):
except HttpException:
f.reply(Kill)
return
self.run_script_hook("responseheaders", f)
f.reply()
return f
@controller.handler
def handle_response(self, f):
self.active_flows.discard(f)
self.state.update_flow(f)
@ -1099,13 +1099,14 @@ class FlowMaster(controller.ServerMaster):
self.add_event('"{}" reloaded.'.format(s.filename), 'info')
return ok
@controller.handler
def handle_tcp_open(self, flow):
# TODO: This would break mitmproxy currently.
# self.state.add_flow(flow)
self.active_flows.add(flow)
self.run_script_hook("tcp_open", flow)
flow.reply()
@controller.handler
def handle_tcp_message(self, flow):
self.run_script_hook("tcp_message", flow)
message = flow.messages[-1]
@ -1116,22 +1117,21 @@ class FlowMaster(controller.ServerMaster):
direction=direction,
), "info")
self.add_event(clean_bin(message.content), "debug")
flow.reply()
@controller.handler
def handle_tcp_error(self, flow):
self.add_event("Error in TCP connection to {}: {}".format(
repr(flow.server_conn.address),
flow.error
), "info")
self.run_script_hook("tcp_error", flow)
flow.reply()
@controller.handler
def handle_tcp_close(self, flow):
self.active_flows.discard(flow)
if self.stream:
self.stream.add(flow)
self.run_script_hook("tcp_close", flow)
flow.reply()
def shutdown(self):
super(FlowMaster, self).shutdown()

View File

@ -461,9 +461,9 @@ class TestFlow(object):
fm = flow.FlowMaster(None, s)
f = tutils.tflow()
f.intercept(mock.Mock())
assert not f.reply.acked
f.kill(fm)
assert f.reply.acked
for i in s.view:
assert "killed" in str(i.error)
def test_killall(self):
s = flow.State()
@ -475,11 +475,9 @@ class TestFlow(object):
f = tutils.tflow()
fm.handle_request(f)
for i in s.view:
assert not i.reply.acked
s.killall(fm)
for i in s.view:
assert i.reply.acked
assert "killed" in str(i.error)
def test_accept_intercept(self):
f = tutils.tflow()

View File

@ -12,6 +12,7 @@ from netlib.http import authentication, http1
from netlib.tutils import raises
from pathod import pathoc, pathod
from mitmproxy import controller
from mitmproxy.proxy.config import HostMatcher
from mitmproxy.exceptions import Kill
from mitmproxy.models import Error, HTTPResponse, HTTPFlow
@ -623,6 +624,7 @@ class TestProxySSL(tservers.HTTPProxyTest):
class MasterRedirectRequest(tservers.TestMaster):
redirect_port = None # Set by TestRedirectRequest
@controller.handler
def handle_request(self, f):
if f.request.path == "/p/201":
@ -636,6 +638,7 @@ class MasterRedirectRequest(tservers.TestMaster):
f.request.port = self.redirect_port
super(MasterRedirectRequest, self).handle_request(f)
@controller.handler
def handle_response(self, f):
f.response.content = str(f.client_conn.address.port)
f.response.headers["server-conn-id"] = str(f.server_conn.source_address.port)
@ -689,10 +692,9 @@ class MasterStreamRequest(tservers.TestMaster):
"""
Enables the stream flag on the flow for all requests
"""
@controller.handler
def handle_responseheaders(self, f):
f.response.stream = True
f.reply()
class TestStreamRequest(tservers.HTTPProxyTest):
@ -739,7 +741,7 @@ class TestStreamRequest(tservers.HTTPProxyTest):
class MasterFakeResponse(tservers.TestMaster):
@controller.handler
def handle_request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp())
f.reply(resp)
@ -767,6 +769,7 @@ class TestServerConnect(tservers.HTTPProxyTest):
class MasterKillRequest(tservers.TestMaster):
@controller.handler
def handle_request(self, f):
f.reply(Kill)
@ -783,6 +786,7 @@ class TestKillRequest(tservers.HTTPProxyTest):
class MasterKillResponse(tservers.TestMaster):
@controller.handler
def handle_response(self, f):
f.reply(Kill)
@ -812,6 +816,7 @@ class TestTransparentResolveError(tservers.TransparentProxyTest):
class MasterIncomplete(tservers.TestMaster):
@controller.handler
def handle_request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp())
resp.content = None
@ -930,7 +935,9 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest):
k = [0] # variable scope workaround: put into array
_func = getattr(master, attr)
def handler(f):
@controller.handler
def handler(*args):
f = args[-1]
k[0] += 1
if not (k[0] in exclude):
f.client_conn.finish()
@ -940,11 +947,14 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest):
setattr(master, attr, handler)
kill_requests(self.chain[1].tmaster, "handle_request",
exclude=[
# fail first request
2, # allow second request
])
kill_requests(
self.chain[1].tmaster,
"handle_request",
exclude = [
# fail first request
2, # allow second request
]
)
kill_requests(self.chain[0].tmaster, "handle_request",
exclude=[
@ -1004,10 +1014,10 @@ class AddUpstreamCertsToClientChainMixin:
ssl = True
servercert = tutils.test_data.path("data/trusted-server.crt")
ssloptions = pathod.SSLOptions(
cn="trusted-cert",
certs=[
("trusted-cert", servercert)
]
cn="trusted-cert",
certs=[
("trusted-cert", servercert)
]
)
def test_add_upstream_certs_to_client_chain(self):