cleanup mitmproxy.controller, raise Kill in Channel (#1085)

This commit is contained in:
Maximilian Hils 2016-04-14 12:03:29 -07:00
parent bc60c26c7b
commit 423c076c61
16 changed files with 266 additions and 179 deletions

View File

@ -451,7 +451,7 @@ class ConsoleMaster(flow.FlowMaster):
self.ui.clear()
def ticker(self, *userdata):
changed = self.tick(self.masterq, timeout=0)
changed = self.tick(timeout=0)
if changed:
self.loop.draw_screen()
signals.update_settings.send()
@ -467,11 +467,6 @@ class ConsoleMaster(flow.FlowMaster):
handle_mouse = not self.options.no_mouse,
)
self.server.start_slave(
controller.Slave,
controller.Channel(self.masterq, self.should_exit)
)
if self.options.rfile:
ret = self.load_flows_path(self.options.rfile)
if ret and self.state.flow_count():
@ -507,6 +502,7 @@ class ConsoleMaster(flow.FlowMaster):
lambda *args: self.view_flowlist()
)
self.start()
try:
self.loop.run()
except Exception:

View File

@ -2,9 +2,130 @@ from __future__ import absolute_import
from six.moves import queue
import threading
from .exceptions import Kill
class DummyReply:
class Master(object):
"""
The master handles mitmproxy's main event loop.
"""
def __init__(self):
self.event_queue = queue.Queue()
self.should_exit = threading.Event()
def start(self):
self.should_exit.clear()
def run(self):
self.start()
try:
while not self.should_exit.is_set():
# Don't choose a very small timeout in Python 2:
# https://github.com/mitmproxy/mitmproxy/issues/443
# TODO: Lower the timeout value if we move to Python 3.
self.tick(0.1)
finally:
self.shutdown()
def tick(self, timeout):
changed = False
try:
# This endless loop runs until the 'Queue.Empty'
# exception is thrown.
while True:
mtype, obj = self.event_queue.get(timeout=timeout)
handle_func = getattr(self, "handle_" + mtype)
handle_func(obj)
self.event_queue.task_done()
changed = True
except queue.Empty:
pass
return changed
def shutdown(self):
self.should_exit.set()
class ServerMaster(Master):
"""
The ServerMaster adds server thread support to the master.
"""
def __init__(self):
super(ServerMaster, self).__init__()
self.servers = []
def add_server(self, server):
# We give a Channel to the server which can be used to communicate with the master
channel = Channel(self.event_queue, self.should_exit)
server.set_channel(channel)
self.servers.append(server)
def start(self):
super(ServerMaster, self).start()
for server in self.servers:
ServerThread(server).start()
def shutdown(self):
for server in self.servers:
server.shutdown()
super(ServerMaster, self).shutdown()
class ServerThread(threading.Thread):
def __init__(self, server):
self.server = server
super(ServerThread, self).__init__()
address = getattr(self.server, "address", None)
self.name = "ServerThread ({})".format(repr(address))
def run(self):
self.server.serve_forever()
class Channel(object):
"""
The only way for the proxy server to communicate with the master
is to use the channel it has been given.
"""
def __init__(self, q, should_exit):
self.q = q
self.should_exit = should_exit
def ask(self, mtype, m):
"""
Decorate a message with a reply attribute, and send it to the
master. Then wait for a response.
Raises:
Kill: All connections should be closed immediately.
"""
m.reply = Reply(m)
self.q.put((mtype, m))
while not self.should_exit.is_set():
try:
# The timeout is here so we can handle a should_exit event.
g = m.reply.q.get(timeout=0.5)
except queue.Empty: # pragma: no cover
continue
if g == Kill:
raise Kill()
return g
raise Kill()
def tell(self, mtype, m):
"""
Decorate a message with a dummy reply attribute, send it to the
master, then return immediately.
"""
m.reply = DummyReply()
self.q.put((mtype, m))
class DummyReply(object):
"""
A reply object that does nothing. Useful when we need an object to seem
like it has a channel, and during testing.
@ -17,8 +138,11 @@ class DummyReply:
self.acked = True
class Reply:
# Special value to distinguish the case where no reply was sent
NO_REPLY = object()
class Reply(object):
"""
Messages sent through a channel are decorated with a "reply" attribute.
This object is used to respond to the message through the return
@ -30,113 +154,10 @@ class Reply:
self.q = queue.Queue()
self.acked = False
def __call__(self, msg=None):
def __call__(self, msg=NO_REPLY):
if not self.acked:
self.acked = True
if msg is None:
if msg is NO_REPLY:
self.q.put(self.obj)
else:
self.q.put(msg)
class Channel:
def __init__(self, q, should_exit):
self.q = q
self.should_exit = should_exit
def ask(self, mtype, m):
"""
Decorate a message with a reply attribute, and send it to the
master. then wait for a response.
"""
m.reply = Reply(m)
self.q.put((mtype, m))
while not self.should_exit.is_set():
try:
# The timeout is here so we can handle a should_exit event.
g = m.reply.q.get(timeout=0.5)
except queue.Empty: # pragma: no cover
continue
return g
def tell(self, mtype, m):
"""
Decorate a message with a dummy reply attribute, send it to the
master, then return immediately.
"""
m.reply = DummyReply()
self.q.put((mtype, m))
class Slave(threading.Thread):
"""
Slaves get a channel end-point through which they can send messages to
the master.
"""
def __init__(self, channel, server):
self.channel, self.server = channel, server
self.server.set_channel(channel)
threading.Thread.__init__(self)
self.name = "SlaveThread ({})".format(repr(self.server.address))
def run(self):
self.server.serve_forever()
class Master(object):
"""
Masters get and respond to messages from slaves.
"""
def __init__(self, server):
"""
server may be None if no server is needed.
"""
self.server = server
self.masterq = queue.Queue()
self.should_exit = threading.Event()
def tick(self, q, timeout):
changed = False
try:
# This endless loop runs until the 'Queue.Empty'
# exception is thrown. If more than one request is in
# the queue, this speeds up every request by 0.1 seconds,
# because get_input(..) function is not blocking.
while True:
msg = q.get(timeout=timeout)
self.handle(*msg)
q.task_done()
changed = True
except queue.Empty:
pass
return changed
def run(self):
self.should_exit.clear()
self.server.start_slave(Slave, Channel(self.masterq, self.should_exit))
while not self.should_exit.is_set():
# Don't choose a very small timeout in Python 2:
# https://github.com/mitmproxy/mitmproxy/issues/443
# TODO: Lower the timeout value if we move to Python 3.
self.tick(self.masterq, 0.1)
self.shutdown()
def handle(self, mtype, obj):
c = "handle_" + mtype
m = getattr(self, c, None)
if m:
m(obj)
else:
obj.reply()
def shutdown(self):
if not self.should_exit.is_set():
self.should_exit.set()
if self.server:
self.server.shutdown()

View File

@ -343,15 +343,8 @@ class DumpMaster(flow.FlowMaster):
self._process_flow(f)
return f
def shutdown(self): # pragma: no cover
return flow.FlowMaster.shutdown(self)
def run(self): # pragma: no cover
if self.o.rfile and not self.o.keepserving:
self.shutdown()
return
try:
return super(DumpMaster, self).run()
except BaseException:
self.shutdown()
raise
super(DumpMaster, self).run()

View File

@ -17,6 +17,13 @@ class ProxyException(Exception):
super(ProxyException, self).__init__(message)
class Kill(ProxyException):
"""
Signal that both client and server connection(s) should be killed immediately.
"""
pass
class ProtocolException(ProxyException):
pass

View File

@ -8,11 +8,9 @@ from abc import abstractmethod, ABCMeta
import hashlib
import six
from six.moves import http_cookies, http_cookiejar
from six.moves import http_cookies, http_cookiejar, urllib
import os
import re
import time
from six.moves import urllib
from netlib import wsgi
from netlib.exceptions import HttpException
@ -21,8 +19,8 @@ from . import controller, tnetstring, filt, script, version, flow_format_compat
from .onboarding import app
from .proxy.config import HostMatcher
from .protocol.http_replay import RequestReplayThread
from .protocol import Kill
from .models import ClientConnection, ServerConnection, HTTPResponse, HTTPFlow, HTTPRequest
from .exceptions import Kill
from .models import ClientConnection, ServerConnection, HTTPFlow, HTTPRequest
class AppRegistry:
@ -630,10 +628,19 @@ class State(object):
self.flows.kill_all(master)
class FlowMaster(controller.Master):
class FlowMaster(controller.ServerMaster):
@property
def server(self):
# At some point, we may want to have support for multiple servers.
# For now, this suffices.
if len(self.servers) > 0:
return self.servers[0]
def __init__(self, server, state):
controller.Master.__init__(self, server)
super(FlowMaster, self).__init__()
if server:
self.add_server(server)
self.state = state
self.server_playback = None
self.client_playback = None
@ -695,7 +702,7 @@ class FlowMaster(controller.Master):
except script.ScriptException as e:
return traceback.format_exc(e)
if use_reloader:
script.reloader.watch(s, lambda: self.masterq.put(("script_change", s)))
script.reloader.watch(s, lambda: self.event_queue.put(("script_change", s)))
self.scripts.append(s)
def _run_single_script_hook(self, script_obj, name, *args, **kwargs):
@ -808,7 +815,7 @@ class FlowMaster(controller.Master):
return True
return None
def tick(self, q, timeout):
def tick(self, timeout):
if self.client_playback:
stop = (
self.client_playback.done() and
@ -833,7 +840,7 @@ class FlowMaster(controller.Master):
self.stop_server_playback()
if exit:
self.shutdown()
return super(FlowMaster, self).tick(q, timeout)
return super(FlowMaster, self).tick(timeout)
def duplicate_flow(self, f):
return self.load_flow(f.copy())
@ -942,7 +949,7 @@ class FlowMaster(controller.Master):
rt = RequestReplayThread(
self.server.config,
f,
self.masterq if run_scripthooks else False,
self.event_queue if run_scripthooks else False,
self.should_exit
)
rt.start() # pragma: no cover
@ -1066,7 +1073,6 @@ class FlowMaster(controller.Master):
m.reply()
def shutdown(self):
self.unload_scripts()
super(FlowMaster, self).shutdown()
# Add all flows that are still active
@ -1076,6 +1082,8 @@ class FlowMaster(controller.Master):
self.stream.add(i)
self.stop_stream()
self.unload_scripts()
def start_stream(self, fp, filt):
self.stream = FilteredFlowWriter(fp, filt)

View File

@ -4,6 +4,7 @@ import uuid
from .. import stateobject, utils, version
from .connections import ClientConnection, ServerConnection
from ..exceptions import Kill
class Error(stateobject.StateObject):
@ -139,8 +140,6 @@ class Flow(stateobject.StateObject):
"""
Kill this request.
"""
from ..protocol import Kill
self.error = Error("Connection killed")
self.intercepted = False
self.reply(Kill)

View File

@ -26,7 +26,7 @@ as late as possible; this makes server replay without any outgoing connections p
"""
from __future__ import (absolute_import, print_function, division)
from .base import Layer, ServerConnectionMixin, Kill
from .base import Layer, ServerConnectionMixin
from .tls import TlsLayer
from .tls import is_tls_record_magic
from .tls import TlsClientHello
@ -36,7 +36,7 @@ from .http2 import Http2Layer
from .rawtcp import RawTCPLayer
__all__ = [
"Layer", "ServerConnectionMixin", "Kill",
"Layer", "ServerConnectionMixin",
"TlsLayer", "is_tls_record_magic", "TlsClientHello",
"UpstreamConnectLayer",
"Http1Layer",

View File

@ -189,10 +189,3 @@ class ServerConnectionMixin(object):
),
sys.exc_info()[2]
)
class Kill(Exception):
"""
Signal that both client and server connection(s) should be killed immediately.
"""

View File

@ -21,7 +21,7 @@ from ..models import (
expect_continue_response
)
from .base import Layer, Kill
from .base import Layer
class _HttpTransmissionLayer(Layer):
@ -194,13 +194,9 @@ class HttpLayer(Layer):
# response was set by an inline script.
# we now need to emulate the responseheaders hook.
flow = self.channel.ask("responseheaders", flow)
if flow == Kill:
raise Kill()
self.log("response", "debug", [repr(flow.response)])
flow = self.channel.ask("response", flow)
if flow == Kill:
raise Kill()
self.send_response_to_client(flow)
if self.check_close_connection(flow):
@ -315,8 +311,6 @@ class HttpLayer(Layer):
# call the appropriate script hook - this is an opportunity for an
# inline script to set flow.stream = True
flow = self.channel.ask("responseheaders", flow)
if flow == Kill:
raise Kill()
if flow.response.stream:
flow.response.data.content = None
@ -352,8 +346,6 @@ class HttpLayer(Layer):
flow.request.scheme = "https" if self.__initial_server_tls else "http"
request_reply = self.channel.ask("request", flow)
if request_reply == Kill:
raise Kill()
if isinstance(request_reply, HTTPResponse):
flow.response = request_reply
return

View File

@ -7,8 +7,7 @@ from netlib.http import http1
from ..controller import Channel
from ..models import Error, HTTPResponse, ServerConnection, make_connect_request
from .base import Kill
from ..exceptions import Kill
# TODO: Doesn't really belong into mitmproxy.protocol...
@ -16,14 +15,14 @@ from .base import Kill
class RequestReplayThread(threading.Thread):
name = "RequestReplayThread"
def __init__(self, config, flow, masterq, should_exit):
def __init__(self, config, flow, event_queue, should_exit):
"""
masterqueue can be a queue or None, if no scripthooks should be
event_queue can be a queue or None, if no scripthooks should be
processed.
"""
self.config, self.flow = config, flow
if masterq:
self.channel = Channel(masterq, should_exit)
if event_queue:
self.channel = Channel(event_queue, should_exit)
else:
self.channel = None
super(RequestReplayThread, self).__init__()
@ -37,9 +36,7 @@ class RequestReplayThread(threading.Thread):
# If we have a channel, run script hooks.
if self.channel:
request_reply = self.channel.ask("request", self.flow)
if request_reply == Kill:
raise Kill()
elif isinstance(request_reply, HTTPResponse):
if isinstance(request_reply, HTTPResponse):
self.flow.response = request_reply
if not self.flow.response:

View File

@ -8,8 +8,7 @@ import six
from netlib import tcp
from netlib.exceptions import TcpException
from netlib.http.http1 import assemble_response
from ..exceptions import ProtocolException, ServerException, ClientHandshakeException
from ..protocol import Kill
from ..exceptions import ProtocolException, ServerException, ClientHandshakeException, Kill
from ..models import ClientConnection, make_error_response
from .modes import HttpUpstreamProxy, HttpProxy, ReverseProxy, TransparentProxy, Socks5Proxy
from .root_context import RootContext, Log
@ -21,7 +20,10 @@ class DummyServer:
def __init__(self, config):
self.config = config
def start_slave(self, *args):
def set_channel(self, channel):
pass
def serve_forever(self):
pass
def shutdown(self):
@ -47,10 +49,6 @@ class ProxyServer(tcp.TCPServer):
)
self.channel = None
def start_slave(self, klass, channel):
slave = klass(channel, self)
slave.start()
def set_channel(self, channel):
self.channel = channel
@ -112,12 +110,9 @@ class ConnectionHandler(object):
self.log("clientconnect", "info")
root_layer = self._create_root_layer()
root_layer = self.channel.ask("clientconnect", root_layer)
if root_layer == Kill:
def root_layer():
raise Kill()
try:
root_layer = self.channel.ask("clientconnect", root_layer)
root_layer()
except Kill:
self.log("Connection killed", "info")

View File

@ -173,20 +173,15 @@ class WebMaster(flow.FlowMaster):
if self.options.app:
self.start_app(self.options.app_host, self.options.app_port)
def tick(self):
flow.FlowMaster.tick(self, self.masterq, timeout=0)
def run(self): # pragma: no cover
self.server.start_slave(
controller.Slave,
controller.Channel(self.masterq, self.should_exit)
)
iol = tornado.ioloop.IOLoop.instance()
http_server = tornado.httpserver.HTTPServer(self.app)
http_server.listen(self.options.wport)
tornado.ioloop.PeriodicCallback(self.tick, 5).start()
iol.add_callback(self.start)
tornado.ioloop.PeriodicCallback(lambda: self.tick(timeout=0), 5).start()
try:
iol.start()
except (Stop, KeyboardInterrupt):

View File

@ -1,11 +1,105 @@
import mock
from mitmproxy import controller
from threading import Thread, Event
from mock import Mock
from mitmproxy.controller import Reply, DummyReply, Channel, ServerThread, ServerMaster, Master
from six.moves import queue
from mitmproxy.exceptions import Kill
from mitmproxy.proxy import DummyServer
from netlib.tutils import raises
class TestMaster:
class TestMaster(object):
def test_simple(self):
def test_default_handler(self):
m = controller.Master(None)
msg = mock.MagicMock()
m.handle("type", msg)
assert msg.reply.call_count == 1
class DummyMaster(Master):
def handle_panic(self, _):
m.should_exit.set()
def tick(self, timeout):
# Speed up test
super(DummyMaster, self).tick(0)
m = DummyMaster()
assert not m.should_exit.is_set()
m.event_queue.put(("panic", 42))
m.run()
assert m.should_exit.is_set()
class TestServerMaster(object):
def test_simple(self):
m = ServerMaster()
s = DummyServer(None)
m.add_server(s)
m.start()
m.shutdown()
m.start()
m.shutdown()
class TestServerThread(object):
def test_simple(self):
m = Mock()
t = ServerThread(m)
t.run()
assert m.serve_forever.called
class TestChannel(object):
def test_tell(self):
q = queue.Queue()
channel = Channel(q, Event())
m = Mock()
channel.tell("test", m)
assert q.get() == ("test", m)
assert m.reply
def test_ask_simple(self):
q = queue.Queue()
def reply():
m, obj = q.get()
assert m == "test"
obj.reply(42)
Thread(target=reply).start()
channel = Channel(q, Event())
assert channel.ask("test", Mock()) == 42
def test_ask_shutdown(self):
q = queue.Queue()
done = Event()
done.set()
channel = Channel(q, done)
with raises(Kill):
channel.ask("test", Mock())
class TestDummyReply(object):
def test_simple(self):
reply = DummyReply()
assert not reply.acked
reply()
assert reply.acked
class TestReply(object):
def test_simple(self):
reply = Reply(42)
assert not reply.acked
reply("foo")
assert reply.acked
assert reply.q.get() == "foo"
def test_default(self):
reply = Reply(42)
reply()
assert reply.q.get() == 42
def test_reply_none(self):
reply = Reply(42)
reply(None)
assert reply.q.get() is None

View File

@ -116,9 +116,8 @@ class TestClientPlaybackState:
c.clear(c.current)
assert c.done()
q = queue.Queue()
fm.state.clear()
fm.tick(q, timeout=0)
fm.tick(timeout=0)
fm.stop_client_playback()
assert not fm.client_playback
@ -858,9 +857,8 @@ class TestFlowMaster:
assert not fm.start_client_playback(pb, False)
fm.client_playback.testing = True
q = queue.Queue()
assert not fm.state.flow_count()
fm.tick(q, 0)
fm.tick(0)
assert fm.state.flow_count()
f.error = Error("error")
@ -904,8 +902,7 @@ class TestFlowMaster:
assert not fm.do_server_playback(r)
assert fm.do_server_playback(tutils.tflow())
q = queue.Queue()
fm.tick(q, 0)
fm.tick(0)
assert fm.should_exit.is_set()
fm.stop_server_playback()

View File

@ -175,7 +175,7 @@ class TestDummyServer:
def test_simple(self):
d = DummyServer(None)
d.start_slave()
d.set_channel(None)
d.shutdown()

View File

@ -13,7 +13,7 @@ from netlib.tutils import raises
from pathod import pathoc, pathod
from mitmproxy.proxy.config import HostMatcher
from mitmproxy.protocol import Kill
from mitmproxy.exceptions import Kill
from mitmproxy.models import Error, HTTPResponse
from . import tutils, tservers
@ -126,7 +126,7 @@ class TcpMixin:
i2 = self.pathod("306")
self._ignore_off()
self.master.masterq.join()
self.master.event_queue.join()
assert n.status_code == 304
assert i.status_code == 305
@ -172,7 +172,7 @@ class TcpMixin:
i2 = self.pathod("306")
self._tcpproxy_off()
self.master.masterq.join()
self.master.event_queue.join()
assert n.status_code == 304
assert i.status_code == 305