From 423c076c61140c2e953313793263a4cac71e33ca Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 14 Apr 2016 12:03:29 -0700 Subject: [PATCH] cleanup mitmproxy.controller, raise Kill in Channel (#1085) --- mitmproxy/console/__init__.py | 8 +- mitmproxy/controller.py | 195 +++++++++++++++++------------- mitmproxy/dump.py | 9 +- mitmproxy/exceptions.py | 7 ++ mitmproxy/flow.py | 32 +++-- mitmproxy/models/flow.py | 3 +- mitmproxy/protocol/__init__.py | 4 +- mitmproxy/protocol/base.py | 7 -- mitmproxy/protocol/http.py | 10 +- mitmproxy/protocol/http_replay.py | 15 +-- mitmproxy/proxy/server.py | 17 +-- mitmproxy/web/__init__.py | 11 +- test/mitmproxy/test_controller.py | 110 +++++++++++++++-- test/mitmproxy/test_flow.py | 9 +- test/mitmproxy/test_proxy.py | 2 +- test/mitmproxy/test_server.py | 6 +- 16 files changed, 266 insertions(+), 179 deletions(-) diff --git a/mitmproxy/console/__init__.py b/mitmproxy/console/__init__.py index 381c133db..32e4d33c2 100644 --- a/mitmproxy/console/__init__.py +++ b/mitmproxy/console/__init__.py @@ -451,7 +451,7 @@ class ConsoleMaster(flow.FlowMaster): self.ui.clear() def ticker(self, *userdata): - changed = self.tick(self.masterq, timeout=0) + changed = self.tick(timeout=0) if changed: self.loop.draw_screen() signals.update_settings.send() @@ -467,11 +467,6 @@ class ConsoleMaster(flow.FlowMaster): handle_mouse = not self.options.no_mouse, ) - self.server.start_slave( - controller.Slave, - controller.Channel(self.masterq, self.should_exit) - ) - if self.options.rfile: ret = self.load_flows_path(self.options.rfile) if ret and self.state.flow_count(): @@ -507,6 +502,7 @@ class ConsoleMaster(flow.FlowMaster): lambda *args: self.view_flowlist() ) + self.start() try: self.loop.run() except Exception: diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 84f243a0d..81978a098 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -2,44 +2,93 @@ from __future__ import absolute_import from six.moves import queue import threading +from .exceptions import Kill -class DummyReply: +class Master(object): """ - A reply object that does nothing. Useful when we need an object to seem - like it has a channel, and during testing. + The master handles mitmproxy's main event loop. """ def __init__(self): - self.acked = False + self.event_queue = queue.Queue() + self.should_exit = threading.Event() - def __call__(self, msg=False): - self.acked = True + def start(self): + self.should_exit.clear() + + def run(self): + self.start() + try: + while not self.should_exit.is_set(): + # Don't choose a very small timeout in Python 2: + # https://github.com/mitmproxy/mitmproxy/issues/443 + # TODO: Lower the timeout value if we move to Python 3. + self.tick(0.1) + finally: + self.shutdown() + + def tick(self, timeout): + changed = False + try: + # This endless loop runs until the 'Queue.Empty' + # exception is thrown. + while True: + mtype, obj = self.event_queue.get(timeout=timeout) + handle_func = getattr(self, "handle_" + mtype) + handle_func(obj) + self.event_queue.task_done() + changed = True + except queue.Empty: + pass + return changed + + def shutdown(self): + self.should_exit.set() -class Reply: - +class ServerMaster(Master): """ - Messages sent through a channel are decorated with a "reply" attribute. - This object is used to respond to the message through the return - channel. + The ServerMaster adds server thread support to the master. """ - def __init__(self, obj): - self.obj = obj - self.q = queue.Queue() - self.acked = False + def __init__(self): + super(ServerMaster, self).__init__() + self.servers = [] - def __call__(self, msg=None): - if not self.acked: - self.acked = True - if msg is None: - self.q.put(self.obj) - else: - self.q.put(msg) + def add_server(self, server): + # We give a Channel to the server which can be used to communicate with the master + channel = Channel(self.event_queue, self.should_exit) + server.set_channel(channel) + self.servers.append(server) + + def start(self): + super(ServerMaster, self).start() + for server in self.servers: + ServerThread(server).start() + + def shutdown(self): + for server in self.servers: + server.shutdown() + super(ServerMaster, self).shutdown() -class Channel: +class ServerThread(threading.Thread): + def __init__(self, server): + self.server = server + super(ServerThread, self).__init__() + address = getattr(self.server, "address", None) + self.name = "ServerThread ({})".format(repr(address)) + + def run(self): + self.server.serve_forever() + + +class Channel(object): + """ + The only way for the proxy server to communicate with the master + is to use the channel it has been given. + """ def __init__(self, q, should_exit): self.q = q @@ -47,8 +96,11 @@ class Channel: def ask(self, mtype, m): """ - Decorate a message with a reply attribute, and send it to the - master. then wait for a response. + Decorate a message with a reply attribute, and send it to the + master. Then wait for a response. + + Raises: + Kill: All connections should be closed immediately. """ m.reply = Reply(m) self.q.put((mtype, m)) @@ -58,85 +110,54 @@ class Channel: g = m.reply.q.get(timeout=0.5) except queue.Empty: # pragma: no cover continue + if g == Kill: + raise Kill() return g + raise Kill() + def tell(self, mtype, m): """ - Decorate a message with a dummy reply attribute, send it to the - master, then return immediately. + Decorate a message with a dummy reply attribute, send it to the + master, then return immediately. """ m.reply = DummyReply() self.q.put((mtype, m)) -class Slave(threading.Thread): - +class DummyReply(object): """ - Slaves get a channel end-point through which they can send messages to - the master. + A reply object that does nothing. Useful when we need an object to seem + like it has a channel, and during testing. """ - def __init__(self, channel, server): - self.channel, self.server = channel, server - self.server.set_channel(channel) - threading.Thread.__init__(self) - self.name = "SlaveThread ({})".format(repr(self.server.address)) + def __init__(self): + self.acked = False - def run(self): - self.server.serve_forever() + def __call__(self, msg=False): + self.acked = True -class Master(object): +# Special value to distinguish the case where no reply was sent +NO_REPLY = object() + +class Reply(object): """ - Masters get and respond to messages from slaves. + Messages sent through a channel are decorated with a "reply" attribute. + This object is used to respond to the message through the return + channel. """ - def __init__(self, server): - """ - server may be None if no server is needed. - """ - self.server = server - self.masterq = queue.Queue() - self.should_exit = threading.Event() + def __init__(self, obj): + self.obj = obj + self.q = queue.Queue() + self.acked = False - def tick(self, q, timeout): - changed = False - try: - # This endless loop runs until the 'Queue.Empty' - # exception is thrown. If more than one request is in - # the queue, this speeds up every request by 0.1 seconds, - # because get_input(..) function is not blocking. - while True: - msg = q.get(timeout=timeout) - self.handle(*msg) - q.task_done() - changed = True - except queue.Empty: - pass - return changed - - def run(self): - self.should_exit.clear() - self.server.start_slave(Slave, Channel(self.masterq, self.should_exit)) - while not self.should_exit.is_set(): - - # Don't choose a very small timeout in Python 2: - # https://github.com/mitmproxy/mitmproxy/issues/443 - # TODO: Lower the timeout value if we move to Python 3. - self.tick(self.masterq, 0.1) - self.shutdown() - - def handle(self, mtype, obj): - c = "handle_" + mtype - m = getattr(self, c, None) - if m: - m(obj) - else: - obj.reply() - - def shutdown(self): - if not self.should_exit.is_set(): - self.should_exit.set() - if self.server: - self.server.shutdown() + def __call__(self, msg=NO_REPLY): + if not self.acked: + self.acked = True + if msg is NO_REPLY: + self.q.put(self.obj) + else: + self.q.put(msg) diff --git a/mitmproxy/dump.py b/mitmproxy/dump.py index 2f519d28d..93487b673 100644 --- a/mitmproxy/dump.py +++ b/mitmproxy/dump.py @@ -343,15 +343,8 @@ class DumpMaster(flow.FlowMaster): self._process_flow(f) return f - def shutdown(self): # pragma: no cover - return flow.FlowMaster.shutdown(self) - def run(self): # pragma: no cover if self.o.rfile and not self.o.keepserving: self.shutdown() return - try: - return super(DumpMaster, self).run() - except BaseException: - self.shutdown() - raise + super(DumpMaster, self).run() \ No newline at end of file diff --git a/mitmproxy/exceptions.py b/mitmproxy/exceptions.py index 53916d1fb..d600f2e38 100644 --- a/mitmproxy/exceptions.py +++ b/mitmproxy/exceptions.py @@ -17,6 +17,13 @@ class ProxyException(Exception): super(ProxyException, self).__init__(message) +class Kill(ProxyException): + """ + Signal that both client and server connection(s) should be killed immediately. + """ + pass + + class ProtocolException(ProxyException): pass diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 9c3785857..13f057a4e 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -8,11 +8,9 @@ from abc import abstractmethod, ABCMeta import hashlib import six -from six.moves import http_cookies, http_cookiejar +from six.moves import http_cookies, http_cookiejar, urllib import os import re -import time -from six.moves import urllib from netlib import wsgi from netlib.exceptions import HttpException @@ -21,8 +19,8 @@ from . import controller, tnetstring, filt, script, version, flow_format_compat from .onboarding import app from .proxy.config import HostMatcher from .protocol.http_replay import RequestReplayThread -from .protocol import Kill -from .models import ClientConnection, ServerConnection, HTTPResponse, HTTPFlow, HTTPRequest +from .exceptions import Kill +from .models import ClientConnection, ServerConnection, HTTPFlow, HTTPRequest class AppRegistry: @@ -630,10 +628,19 @@ class State(object): self.flows.kill_all(master) -class FlowMaster(controller.Master): +class FlowMaster(controller.ServerMaster): + + @property + def server(self): + # At some point, we may want to have support for multiple servers. + # For now, this suffices. + if len(self.servers) > 0: + return self.servers[0] def __init__(self, server, state): - controller.Master.__init__(self, server) + super(FlowMaster, self).__init__() + if server: + self.add_server(server) self.state = state self.server_playback = None self.client_playback = None @@ -695,7 +702,7 @@ class FlowMaster(controller.Master): except script.ScriptException as e: return traceback.format_exc(e) if use_reloader: - script.reloader.watch(s, lambda: self.masterq.put(("script_change", s))) + script.reloader.watch(s, lambda: self.event_queue.put(("script_change", s))) self.scripts.append(s) def _run_single_script_hook(self, script_obj, name, *args, **kwargs): @@ -808,7 +815,7 @@ class FlowMaster(controller.Master): return True return None - def tick(self, q, timeout): + def tick(self, timeout): if self.client_playback: stop = ( self.client_playback.done() and @@ -833,7 +840,7 @@ class FlowMaster(controller.Master): self.stop_server_playback() if exit: self.shutdown() - return super(FlowMaster, self).tick(q, timeout) + return super(FlowMaster, self).tick(timeout) def duplicate_flow(self, f): return self.load_flow(f.copy()) @@ -942,7 +949,7 @@ class FlowMaster(controller.Master): rt = RequestReplayThread( self.server.config, f, - self.masterq if run_scripthooks else False, + self.event_queue if run_scripthooks else False, self.should_exit ) rt.start() # pragma: no cover @@ -1066,7 +1073,6 @@ class FlowMaster(controller.Master): m.reply() def shutdown(self): - self.unload_scripts() super(FlowMaster, self).shutdown() # Add all flows that are still active @@ -1076,6 +1082,8 @@ class FlowMaster(controller.Master): self.stream.add(i) self.stop_stream() + self.unload_scripts() + def start_stream(self, fp, filt): self.stream = FilteredFlowWriter(fp, filt) diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index 45b3b5e05..594147ec8 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -4,6 +4,7 @@ import uuid from .. import stateobject, utils, version from .connections import ClientConnection, ServerConnection +from ..exceptions import Kill class Error(stateobject.StateObject): @@ -139,8 +140,6 @@ class Flow(stateobject.StateObject): """ Kill this request. """ - from ..protocol import Kill - self.error = Error("Connection killed") self.intercepted = False self.reply(Kill) diff --git a/mitmproxy/protocol/__init__.py b/mitmproxy/protocol/__init__.py index d44e25e93..3d9fa7d4f 100644 --- a/mitmproxy/protocol/__init__.py +++ b/mitmproxy/protocol/__init__.py @@ -26,7 +26,7 @@ as late as possible; this makes server replay without any outgoing connections p """ from __future__ import (absolute_import, print_function, division) -from .base import Layer, ServerConnectionMixin, Kill +from .base import Layer, ServerConnectionMixin from .tls import TlsLayer from .tls import is_tls_record_magic from .tls import TlsClientHello @@ -36,7 +36,7 @@ from .http2 import Http2Layer from .rawtcp import RawTCPLayer __all__ = [ - "Layer", "ServerConnectionMixin", "Kill", + "Layer", "ServerConnectionMixin", "TlsLayer", "is_tls_record_magic", "TlsClientHello", "UpstreamConnectLayer", "Http1Layer", diff --git a/mitmproxy/protocol/base.py b/mitmproxy/protocol/base.py index 41d47281f..536f2753c 100644 --- a/mitmproxy/protocol/base.py +++ b/mitmproxy/protocol/base.py @@ -189,10 +189,3 @@ class ServerConnectionMixin(object): ), sys.exc_info()[2] ) - - -class Kill(Exception): - - """ - Signal that both client and server connection(s) should be killed immediately. - """ diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index db24ca924..bdda80cd2 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -21,7 +21,7 @@ from ..models import ( expect_continue_response ) -from .base import Layer, Kill +from .base import Layer class _HttpTransmissionLayer(Layer): @@ -194,13 +194,9 @@ class HttpLayer(Layer): # response was set by an inline script. # we now need to emulate the responseheaders hook. flow = self.channel.ask("responseheaders", flow) - if flow == Kill: - raise Kill() self.log("response", "debug", [repr(flow.response)]) flow = self.channel.ask("response", flow) - if flow == Kill: - raise Kill() self.send_response_to_client(flow) if self.check_close_connection(flow): @@ -315,8 +311,6 @@ class HttpLayer(Layer): # call the appropriate script hook - this is an opportunity for an # inline script to set flow.stream = True flow = self.channel.ask("responseheaders", flow) - if flow == Kill: - raise Kill() if flow.response.stream: flow.response.data.content = None @@ -352,8 +346,6 @@ class HttpLayer(Layer): flow.request.scheme = "https" if self.__initial_server_tls else "http" request_reply = self.channel.ask("request", flow) - if request_reply == Kill: - raise Kill() if isinstance(request_reply, HTTPResponse): flow.response = request_reply return diff --git a/mitmproxy/protocol/http_replay.py b/mitmproxy/protocol/http_replay.py index 62f842ce2..e78af074d 100644 --- a/mitmproxy/protocol/http_replay.py +++ b/mitmproxy/protocol/http_replay.py @@ -7,8 +7,7 @@ from netlib.http import http1 from ..controller import Channel from ..models import Error, HTTPResponse, ServerConnection, make_connect_request -from .base import Kill - +from ..exceptions import Kill # TODO: Doesn't really belong into mitmproxy.protocol... @@ -16,14 +15,14 @@ from .base import Kill class RequestReplayThread(threading.Thread): name = "RequestReplayThread" - def __init__(self, config, flow, masterq, should_exit): + def __init__(self, config, flow, event_queue, should_exit): """ - masterqueue can be a queue or None, if no scripthooks should be + event_queue can be a queue or None, if no scripthooks should be processed. """ self.config, self.flow = config, flow - if masterq: - self.channel = Channel(masterq, should_exit) + if event_queue: + self.channel = Channel(event_queue, should_exit) else: self.channel = None super(RequestReplayThread, self).__init__() @@ -37,9 +36,7 @@ class RequestReplayThread(threading.Thread): # If we have a channel, run script hooks. if self.channel: request_reply = self.channel.ask("request", self.flow) - if request_reply == Kill: - raise Kill() - elif isinstance(request_reply, HTTPResponse): + if isinstance(request_reply, HTTPResponse): self.flow.response = request_reply if not self.flow.response: diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 8b62ff938..4304bd0be 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -8,8 +8,7 @@ import six from netlib import tcp from netlib.exceptions import TcpException from netlib.http.http1 import assemble_response -from ..exceptions import ProtocolException, ServerException, ClientHandshakeException -from ..protocol import Kill +from ..exceptions import ProtocolException, ServerException, ClientHandshakeException, Kill from ..models import ClientConnection, make_error_response from .modes import HttpUpstreamProxy, HttpProxy, ReverseProxy, TransparentProxy, Socks5Proxy from .root_context import RootContext, Log @@ -21,7 +20,10 @@ class DummyServer: def __init__(self, config): self.config = config - def start_slave(self, *args): + def set_channel(self, channel): + pass + + def serve_forever(self): pass def shutdown(self): @@ -47,10 +49,6 @@ class ProxyServer(tcp.TCPServer): ) self.channel = None - def start_slave(self, klass, channel): - slave = klass(channel, self) - slave.start() - def set_channel(self, channel): self.channel = channel @@ -112,12 +110,9 @@ class ConnectionHandler(object): self.log("clientconnect", "info") root_layer = self._create_root_layer() - root_layer = self.channel.ask("clientconnect", root_layer) - if root_layer == Kill: - def root_layer(): - raise Kill() try: + root_layer = self.channel.ask("clientconnect", root_layer) root_layer() except Kill: self.log("Connection killed", "info") diff --git a/mitmproxy/web/__init__.py b/mitmproxy/web/__init__.py index 852d9fc56..4ad0b0822 100644 --- a/mitmproxy/web/__init__.py +++ b/mitmproxy/web/__init__.py @@ -173,20 +173,15 @@ class WebMaster(flow.FlowMaster): if self.options.app: self.start_app(self.options.app_host, self.options.app_port) - def tick(self): - flow.FlowMaster.tick(self, self.masterq, timeout=0) - def run(self): # pragma: no cover - self.server.start_slave( - controller.Slave, - controller.Channel(self.masterq, self.should_exit) - ) + iol = tornado.ioloop.IOLoop.instance() http_server = tornado.httpserver.HTTPServer(self.app) http_server.listen(self.options.wport) - tornado.ioloop.PeriodicCallback(self.tick, 5).start() + iol.add_callback(self.start) + tornado.ioloop.PeriodicCallback(lambda: self.tick(timeout=0), 5).start() try: iol.start() except (Stop, KeyboardInterrupt): diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py index eb3f7df4a..f7bf615a6 100644 --- a/test/mitmproxy/test_controller.py +++ b/test/mitmproxy/test_controller.py @@ -1,11 +1,105 @@ -import mock -from mitmproxy import controller +from threading import Thread, Event + +from mock import Mock + +from mitmproxy.controller import Reply, DummyReply, Channel, ServerThread, ServerMaster, Master +from six.moves import queue + +from mitmproxy.exceptions import Kill +from mitmproxy.proxy import DummyServer +from netlib.tutils import raises -class TestMaster: +class TestMaster(object): + def test_simple(self): - def test_default_handler(self): - m = controller.Master(None) - msg = mock.MagicMock() - m.handle("type", msg) - assert msg.reply.call_count == 1 + class DummyMaster(Master): + def handle_panic(self, _): + m.should_exit.set() + + def tick(self, timeout): + # Speed up test + super(DummyMaster, self).tick(0) + + m = DummyMaster() + assert not m.should_exit.is_set() + m.event_queue.put(("panic", 42)) + m.run() + assert m.should_exit.is_set() + + +class TestServerMaster(object): + def test_simple(self): + m = ServerMaster() + s = DummyServer(None) + m.add_server(s) + m.start() + m.shutdown() + m.start() + m.shutdown() + + +class TestServerThread(object): + def test_simple(self): + m = Mock() + t = ServerThread(m) + t.run() + assert m.serve_forever.called + + +class TestChannel(object): + def test_tell(self): + q = queue.Queue() + channel = Channel(q, Event()) + m = Mock() + channel.tell("test", m) + assert q.get() == ("test", m) + assert m.reply + + def test_ask_simple(self): + q = queue.Queue() + + def reply(): + m, obj = q.get() + assert m == "test" + obj.reply(42) + + Thread(target=reply).start() + + channel = Channel(q, Event()) + assert channel.ask("test", Mock()) == 42 + + def test_ask_shutdown(self): + q = queue.Queue() + done = Event() + done.set() + channel = Channel(q, done) + with raises(Kill): + channel.ask("test", Mock()) + + +class TestDummyReply(object): + def test_simple(self): + reply = DummyReply() + assert not reply.acked + reply() + assert reply.acked + + +class TestReply(object): + def test_simple(self): + reply = Reply(42) + assert not reply.acked + reply("foo") + assert reply.acked + assert reply.q.get() == "foo" + + def test_default(self): + reply = Reply(42) + reply() + assert reply.q.get() == 42 + + def test_reply_none(self): + reply = Reply(42) + reply(None) + assert reply.q.get() is None diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 7ede7a81c..60f6b1a92 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -116,9 +116,8 @@ class TestClientPlaybackState: c.clear(c.current) assert c.done() - q = queue.Queue() fm.state.clear() - fm.tick(q, timeout=0) + fm.tick(timeout=0) fm.stop_client_playback() assert not fm.client_playback @@ -858,9 +857,8 @@ class TestFlowMaster: assert not fm.start_client_playback(pb, False) fm.client_playback.testing = True - q = queue.Queue() assert not fm.state.flow_count() - fm.tick(q, 0) + fm.tick(0) assert fm.state.flow_count() f.error = Error("error") @@ -904,8 +902,7 @@ class TestFlowMaster: assert not fm.do_server_playback(r) assert fm.do_server_playback(tutils.tflow()) - q = queue.Queue() - fm.tick(q, 0) + fm.tick(0) assert fm.should_exit.is_set() fm.stop_server_playback() diff --git a/test/mitmproxy/test_proxy.py b/test/mitmproxy/test_proxy.py index fddb851e8..e08971356 100644 --- a/test/mitmproxy/test_proxy.py +++ b/test/mitmproxy/test_proxy.py @@ -175,7 +175,7 @@ class TestDummyServer: def test_simple(self): d = DummyServer(None) - d.start_slave() + d.set_channel(None) d.shutdown() diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 0d56e7ea5..8843ee62c 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -13,7 +13,7 @@ from netlib.tutils import raises from pathod import pathoc, pathod from mitmproxy.proxy.config import HostMatcher -from mitmproxy.protocol import Kill +from mitmproxy.exceptions import Kill from mitmproxy.models import Error, HTTPResponse from . import tutils, tservers @@ -126,7 +126,7 @@ class TcpMixin: i2 = self.pathod("306") self._ignore_off() - self.master.masterq.join() + self.master.event_queue.join() assert n.status_code == 304 assert i.status_code == 305 @@ -172,7 +172,7 @@ class TcpMixin: i2 = self.pathod("306") self._tcpproxy_off() - self.master.masterq.join() + self.master.event_queue.join() assert n.status_code == 304 assert i.status_code == 305