diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 55a4dbcfb..5acbebf21 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -9,7 +9,7 @@ import cookielib import os import re import urlparse - +import inspect from netlib import wsgi from netlib.exceptions import HttpException @@ -21,6 +21,11 @@ from .proxy.config import HostMatcher from .protocol.http_replay import RequestReplayThread from .protocol import Kill from .models import ClientConnection, ServerConnection, HTTPResponse, HTTPFlow, HTTPRequest +from . import contentviews as cv + + +class PluginError(Exception): + pass class AppRegistry: @@ -614,6 +619,43 @@ class State(object): self.flows.kill_all(master) +class Plugins(object): + def __init__(self): + self._view_plugins = {} + + def __iter__(self): + for plugin_type in ('view_plugins',): + yield (plugin_type, getattr(self, '_' + plugin_type)) + + def __getitem__(self, key): + if key in ('view_plugins',): + return getattr(self, '_' + key) + else: + return None + + def register_view(self, id, **kwargs): + if self._view_plugins.get(id): + raise PluginError("Duplicate view registration for %s" % (id, )) + + if not kwargs.get('class_ref') or not \ + callable(kwargs['class_ref']) or not \ + isinstance(kwargs['class_ref'], type): + raise PluginError("No custom content view class passed for view %s" % (id, )) + + script_path = inspect.stack()[1][1] + + view_plugin = { + 'title': kwargs.get('title') or id, + 'class_ref': kwargs['class_ref'], + 'script_path': script_path, + } + self._view_plugins[id] = view_plugin + + cv.add(kwargs['class_ref']()) + + print("Registered view plugin %s from script %s" % (kwargs['title'], script_path)) + + class FlowMaster(controller.Master): def __init__(self, server, state): controller.Master.__init__(self, server) @@ -643,6 +685,8 @@ class FlowMaster(controller.Master): self.stream = None self.apps = AppRegistry() + self.plugins = Plugins() + def start_app(self, host, port): self.apps.add( app.mapp, diff --git a/libmproxy/script.py b/libmproxy/script.py index 9d051c129..f11c5cd87 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -56,6 +56,13 @@ class ScriptContext: def app_registry(self): return self._master.apps + @property + def plugins(self): + if hasattr(self._master, 'plugins'): + return self._master.plugins + + return None + class Script: """ diff --git a/test/test_custom_contentview.py b/test/test_custom_contentview.py new file mode 100644 index 000000000..2ca184d0a --- /dev/null +++ b/test/test_custom_contentview.py @@ -0,0 +1,49 @@ +from libmproxy import script, flow +import libmproxy.contentviews as cv +from netlib.http import Headers + + +def test_custom_views(): + plugins = flow.Plugins() + + # two types: view and action + assert 'view_plugins' in dict(plugins).keys() + + view_plugins = plugins['view_plugins'] + assert len(view_plugins) == 0 + + class ViewNoop(cv.View): + name = "noop" + prompt = ("noop", "n") + content_types = ["text/none"] + + def __call__(self, data, **metadata): + return "noop", cv.format_text(data) + + plugins.register_view('noop', + title='Noop View Plugin', + class_ref=ViewNoop) + + assert len(view_plugins) == 1 + assert view_plugins['noop']['title'] == 'Noop View Plugin' + + assert cv.get("noop") + + r = cv.get_content_view( + cv.get("noop"), + "[1, 2, 3]", + headers=Headers( + content_type="text/plain" + ) + ) + assert "noop" in r[0] + + # now try content-type matching + r = cv.get_content_view( + cv.get("Auto"), + "[1, 2, 3]", + headers=Headers( + content_type="text/none" + ) + ) + assert "noop" in r[0] diff --git a/test/test_script.py b/test/test_script.py index 1b0e5a5b4..f0883ad5e 100644 --- a/test/test_script.py +++ b/test/test_script.py @@ -127,3 +127,13 @@ def test_command_parsing(): absfilepath = os.path.normcase(tutils.test_data.path("scripts/a.py")) s = script.Script(absfilepath, fm) assert os.path.isfile(s.args[0]) + + +def test_script_plugins(): + s = flow.State() + fm = flow.FlowMaster(None, s) + sp = tutils.test_data.path("scripts/a.py") + p = script.Script("%s --var 40" % sp, fm) + + assert hasattr(p.ctx, 'plugins') +