Merge pull request #1356 from cortesi/script

Scripts to addon
This commit is contained in:
Aldo Cortesi 2016-07-15 16:48:01 +12:00 committed by GitHub
commit 64e16f5112
35 changed files with 658 additions and 690 deletions

View File

@ -2,7 +2,7 @@
This inline script utilizes harparser.HAR from This inline script utilizes harparser.HAR from
https://github.com/JustusW/harparser to generate a HAR log object. https://github.com/JustusW/harparser to generate a HAR log object.
""" """
import mitmproxy import mitmproxy.ctx
import six import six
import sys import sys
import pytz import pytz
@ -221,9 +221,11 @@ def done():
if context.dump_file == '-': if context.dump_file == '-':
mitmproxy.ctx.log(pprint.pformat(json.loads(json_dump))) mitmproxy.ctx.log(pprint.pformat(json.loads(json_dump)))
elif context.dump_file.endswith('.zhar'): elif context.dump_file.endswith('.zhar'):
file(context.dump_file, "w").write(compressed_json_dump) with open(context.dump_file, "wb") as f:
f.write(compressed_json_dump)
else: else:
file(context.dump_file, "w").write(json_dump) with open(context.dump_file, "wb") as f:
f.write(json_dump)
mitmproxy.ctx.log( mitmproxy.ctx.log(
"HAR log finished with %s bytes (%s bytes compressed)" % ( "HAR log finished with %s bytes (%s bytes compressed)" % (
len(json_dump), len(compressed_json_dump) len(json_dump), len(compressed_json_dump)

View File

@ -6,11 +6,18 @@ import mitmproxy
def start(): def start():
""" """
Called once on script startup, before any other events. Called once on script startup before any other events
""" """
mitmproxy.ctx.log("start") mitmproxy.ctx.log("start")
def configure(options):
"""
Called once on script startup before any other events, and whenever options changes.
"""
mitmproxy.ctx.log("configure")
def clientconnect(root_layer): def clientconnect(root_layer):
""" """
Called when a client initiates a connection to the proxy. Note that a Called when a client initiates a connection to the proxy. Note that a

View File

@ -21,6 +21,7 @@ class Addons(object):
def add(self, *addons): def add(self, *addons):
self.chain.extend(addons) self.chain.extend(addons)
for i in addons: for i in addons:
self.invoke_with_context(i, "start")
self.invoke_with_context(i, "configure", self.master.options) self.invoke_with_context(i, "configure", self.master.options)
def remove(self, addon): def remove(self, addon):

View File

@ -4,6 +4,7 @@ from mitmproxy.builtins import anticache
from mitmproxy.builtins import anticomp from mitmproxy.builtins import anticomp
from mitmproxy.builtins import stickyauth from mitmproxy.builtins import stickyauth
from mitmproxy.builtins import stickycookie from mitmproxy.builtins import stickycookie
from mitmproxy.builtins import script
from mitmproxy.builtins import stream from mitmproxy.builtins import stream
@ -13,5 +14,6 @@ def default_addons():
anticomp.AntiComp(), anticomp.AntiComp(),
stickyauth.StickyAuth(), stickyauth.StickyAuth(),
stickycookie.StickyCookie(), stickycookie.StickyCookie(),
script.ScriptLoader(),
stream.Stream(), stream.Stream(),
] ]

View File

@ -0,0 +1,185 @@
from __future__ import absolute_import, print_function, division
import contextlib
import os
import shlex
import sys
import threading
import traceback
from mitmproxy import exceptions
from mitmproxy import controller
from mitmproxy import ctx
import watchdog.events
from watchdog.observers import polling
def parse_command(command):
"""
Returns a (path, args) tuple.
"""
if not command or not command.strip():
raise exceptions.AddonError("Empty script command.")
# Windows: escape all backslashes in the path.
if os.name == "nt": # pragma: no cover
backslashes = shlex.split(command, posix=False)[0].count("\\")
command = command.replace("\\", "\\\\", backslashes)
args = shlex.split(command) # pragma: no cover
args[0] = os.path.expanduser(args[0])
if not os.path.exists(args[0]):
raise exceptions.AddonError(
("Script file not found: %s.\r\n"
"If your script path contains spaces, "
"make sure to wrap it in additional quotes, e.g. -s \"'./foo bar/baz.py' --args\".") %
args[0])
elif os.path.isdir(args[0]):
raise exceptions.AddonError("Not a file: %s" % args[0])
return args[0], args[1:]
@contextlib.contextmanager
def scriptenv(path, args):
oldargs = sys.argv
sys.argv = [path] + args
script_dir = os.path.dirname(os.path.abspath(path))
sys.path.append(script_dir)
try:
yield
except Exception:
_, _, tb = sys.exc_info()
scriptdir = os.path.dirname(os.path.abspath(path))
for i, s in enumerate(reversed(traceback.extract_tb(tb))):
tb = tb.tb_next
if not os.path.abspath(s[0]).startswith(scriptdir):
break
ctx.log.error("Script error: %s" % "".join(traceback.format_tb(tb)))
finally:
sys.argv = oldargs
sys.path.pop()
def load_script(path, args):
with open(path, "rb") as f:
try:
code = compile(f.read(), path, 'exec')
except SyntaxError as e:
ctx.log.error(
"Script error: %s line %s: %s" % (
e.filename, e.lineno, e.msg
)
)
return
ns = {'__file__': os.path.abspath(path)}
with scriptenv(path, args):
exec(code, ns, ns)
return ns
class ReloadHandler(watchdog.events.FileSystemEventHandler):
def __init__(self, callback):
self.callback = callback
def on_modified(self, event):
self.callback()
def on_created(self, event):
self.callback()
class Script:
"""
An addon that manages a single script.
"""
def __init__(self, command):
self.name = command
self.command = command
self.path, self.args = parse_command(command)
self.ns = None
self.observer = None
self.dead = False
self.last_options = None
self.should_reload = threading.Event()
for i in controller.Events:
if not hasattr(self, i):
def mkprox():
evt = i
def prox(*args, **kwargs):
self.run(evt, *args, **kwargs)
return prox
setattr(self, i, mkprox())
def run(self, name, *args, **kwargs):
# It's possible for ns to be un-initialised if we failed during
# configure
if self.ns is not None and not self.dead:
func = self.ns.get(name)
if func:
with scriptenv(self.path, self.args):
func(*args, **kwargs)
def reload(self):
self.should_reload.set()
def tick(self):
if self.should_reload.is_set():
self.should_reload.clear()
ctx.log.info("Reloading script: %s" % self.name)
self.ns = load_script(self.path, self.args)
self.configure(self.last_options)
else:
self.run("tick")
def start(self):
self.ns = load_script(self.path, self.args)
self.run("start")
def configure(self, options):
self.last_options = options
if not self.observer:
self.observer = polling.PollingObserver()
# Bind the handler to the real underlying master object
self.observer.schedule(
ReloadHandler(self.reload),
os.path.dirname(self.path) or "."
)
self.observer.start()
self.run("configure", options)
def done(self):
self.run("done")
self.dead = True
class ScriptLoader():
"""
An addon that manages loading scripts from options.
"""
def configure(self, options):
for s in options.scripts:
if options.scripts.count(s) > 1:
raise exceptions.OptionsError("Duplicate script: %s" % s)
for a in ctx.master.addons.chain[:]:
if isinstance(a, Script) and a.name not in options.scripts:
ctx.log.info("Un-loading script: %s" % a.name)
ctx.master.addons.remove(a)
current = {}
for a in ctx.master.addons.chain[:]:
if isinstance(a, Script):
current[a.name] = a
ctx.master.addons.chain.remove(a)
for s in options.scripts:
if s in current:
ctx.master.addons.chain.append(current[s])
else:
ctx.log.info("Loading script: %s" % s)
sc = Script(s)
ctx.master.addons.add(sc)

View File

@ -6,11 +6,12 @@ import re
import urwid import urwid
from mitmproxy import exceptions
from mitmproxy import filt from mitmproxy import filt
from mitmproxy import script from mitmproxy.builtins import script
from mitmproxy import utils
from mitmproxy.console import common from mitmproxy.console import common
from mitmproxy.console import signals from mitmproxy.console import signals
from netlib import strutils
from netlib.http import cookies from netlib.http import cookies
from netlib.http import user_agents from netlib.http import user_agents
@ -55,7 +56,7 @@ class TextColumn:
o = editor.walker.get_current_value() o = editor.walker.get_current_value()
if o is not None: if o is not None:
n = editor.master.spawn_editor(o.encode("string-escape")) n = editor.master.spawn_editor(o.encode("string-escape"))
n = utils.clean_hanging_newline(n) n = strutils.clean_hanging_newline(n)
editor.walker.set_current_value(n, False) editor.walker.set_current_value(n, False)
editor.walker._modified() editor.walker._modified()
elif key in ["enter"]: elif key in ["enter"]:
@ -643,8 +644,8 @@ class ScriptEditor(GridEditor):
def is_error(self, col, val): def is_error(self, col, val):
try: try:
script.Script.parse_command(val) script.parse_command(val)
except script.ScriptException as e: except exceptions.AddonError as e:
return str(e) return str(e)

View File

@ -248,23 +248,6 @@ class ConsoleMaster(flow.FlowMaster):
if options.server_replay: if options.server_replay:
self.server_playback_path(options.server_replay) self.server_playback_path(options.server_replay)
if options.scripts:
for i in options.scripts:
try:
self.load_script(i)
except exceptions.ScriptException as e:
print("Script load error: {}".format(e), file=sys.stderr)
sys.exit(1)
if options.outfile:
err = self.start_stream_to_path(
options.outfile[0],
options.outfile[1]
)
if err:
print("Stream file error: {}".format(err), file=sys.stderr)
sys.exit(1)
self.view_stack = [] self.view_stack = []
if options.app: if options.app:
@ -685,20 +668,7 @@ class ConsoleMaster(flow.FlowMaster):
self.refresh_focus() self.refresh_focus()
def edit_scripts(self, scripts): def edit_scripts(self, scripts):
commands = [x[0] for x in scripts] # remove outer array self.options.scripts = [x[0] for x in scripts]
if commands == [s.command for s in self.scripts]:
return
self.unload_scripts()
for command in commands:
try:
self.load_script(command)
except exceptions.ScriptException as e:
signals.status_message.send(
message='Error loading "{}".'.format(command)
)
signals.add_event('Error loading "{}":\n{}'.format(command, e), "error")
signals.update_settings.send(self)
def stop_client_playback_prompt(self, a): def stop_client_playback_prompt(self, a):
if a != "n": if a != "n":

View File

@ -54,7 +54,7 @@ class Options(urwid.WidgetWrap):
select.Option( select.Option(
"Scripts", "Scripts",
"S", "S",
lambda: master.scripts, lambda: master.options.scripts,
self.scripts self.scripts
), ),
@ -160,12 +160,14 @@ class Options(urwid.WidgetWrap):
self.master.replacehooks.clear() self.master.replacehooks.clear()
self.master.set_ignore_filter([]) self.master.set_ignore_filter([])
self.master.set_tcp_filter([]) self.master.set_tcp_filter([])
self.master.scripts = []
self.master.options.anticache = False self.master.options.update(
self.master.options.anticomp = False scripts = [],
self.master.options.stickyauth = None anticache = False,
self.master.options.stickycookie = None anticomp = False,
stickyauth = None,
stickycookie = None
)
self.master.state.default_body_view = contentviews.get("Auto") self.master.state.default_body_view = contentviews.get("Auto")
@ -234,7 +236,7 @@ class Options(urwid.WidgetWrap):
self.master.view_grideditor( self.master.view_grideditor(
grideditor.ScriptEditor( grideditor.ScriptEditor(
self.master, self.master,
[[i.command] for i in self.master.scripts], [[i] for i in self.master.options.scripts],
self.master.edit_scripts self.master.edit_scripts
) )
) )

View File

@ -218,14 +218,13 @@ class StatusBar(urwid.WidgetWrap):
dst.address.host, dst.address.host,
dst.address.port dst.address.port
)) ))
if self.master.scripts: if self.master.options.scripts:
r.append("[") r.append("[")
r.append(("heading_key", "s")) r.append(("heading_key", "s"))
r.append("cripts:%s]" % len(self.master.scripts)) r.append("cripts:%s]" % len(self.master.options.scripts))
# r.append("[lt:%0.3f]"%self.master.looptime)
if self.master.stream: if self.master.options.outfile:
r.append("[W:%s]" % self.master.stream_path) r.append("[W:%s]" % self.master.outfile[0])
return r return r

View File

@ -33,6 +33,11 @@ Events = frozenset([
"error", "error",
"log", "log",
"start",
"configure",
"done",
"tick",
"script_change", "script_change",
]) ])
@ -44,7 +49,17 @@ class Log(object):
def __call__(self, text, level="info"): def __call__(self, text, level="info"):
self.master.add_event(text, level) self.master.add_event(text, level)
# We may want to add .log(), .warn() etc. here at a later point in time def debug(self, txt):
self(txt, "debug")
def info(self, txt):
self(txt, "info")
def warn(self, txt):
self(txt, "warn")
def error(self, txt):
self(txt, "error")
class Master(object): class Master(object):
@ -97,26 +112,25 @@ class Master(object):
self.shutdown() self.shutdown()
def tick(self, timeout): def tick(self, timeout):
with self.handlecontext():
self.addons("tick")
changed = False changed = False
try: try:
# This endless loop runs until the 'Queue.Empty' mtype, obj = self.event_queue.get(timeout=timeout)
# exception is thrown. if mtype not in Events:
while True: raise exceptions.ControlException("Unknown event %s" % repr(mtype))
mtype, obj = self.event_queue.get(timeout=timeout) handle_func = getattr(self, mtype)
if mtype not in Events: if not callable(handle_func):
raise exceptions.ControlException("Unknown event %s" % repr(mtype)) raise exceptions.ControlException("Handler %s not callable" % mtype)
handle_func = getattr(self, mtype) if not handle_func.__dict__.get("__handler"):
if not callable(handle_func): raise exceptions.ControlException(
raise exceptions.ControlException("Handler %s not callable" % mtype) "Handler function %s is not decorated with controller.handler" % (
if not handle_func.__dict__.get("__handler"): handle_func
raise exceptions.ControlException(
"Handler function %s is not decorated with controller.handler" % (
handle_func
)
) )
handle_func(obj) )
self.event_queue.task_done() handle_func(obj)
changed = True self.event_queue.task_done()
changed = True
except queue.Empty: except queue.Empty:
pass pass
return changed return changed

View File

@ -93,13 +93,6 @@ class DumpMaster(flow.FlowMaster):
not options.keepserving not options.keepserving
) )
scripts = options.scripts or []
for command in scripts:
try:
self.load_script(command, use_reloader=True)
except exceptions.ScriptException as e:
raise DumpError(str(e))
if options.rfile: if options.rfile:
try: try:
self.load_flows_file(options.rfile) self.load_flows_file(options.rfile)
@ -335,6 +328,5 @@ class DumpMaster(flow.FlowMaster):
def run(self): # pragma: no cover def run(self): # pragma: no cover
if self.options.rfile and not self.options.keepserving: if self.options.rfile and not self.options.keepserving:
self.unload_scripts() # make sure to trigger script unload events.
return return
super(DumpMaster, self).run() super(DumpMaster, self).run()

View File

@ -99,3 +99,7 @@ class ControlException(ProxyException):
class OptionsError(Exception): class OptionsError(Exception):
pass pass
class AddonError(Exception):
pass

View File

@ -9,7 +9,6 @@ import netlib.exceptions
from mitmproxy import controller from mitmproxy import controller
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import models from mitmproxy import models
from mitmproxy import script
from mitmproxy.flow import io from mitmproxy.flow import io
from mitmproxy.flow import modules from mitmproxy.flow import modules
from mitmproxy.onboarding import app from mitmproxy.onboarding import app
@ -35,8 +34,6 @@ class FlowMaster(controller.Master):
self.server_playback = None # type: Optional[modules.ServerPlaybackState] self.server_playback = None # type: Optional[modules.ServerPlaybackState]
self.client_playback = None # type: Optional[modules.ClientPlaybackState] self.client_playback = None # type: Optional[modules.ClientPlaybackState]
self.kill_nonreplay = False self.kill_nonreplay = False
self.scripts = [] # type: List[script.Script]
self.pause_scripts = False
self.stream_large_bodies = None # type: Optional[modules.StreamLargeBodies] self.stream_large_bodies = None # type: Optional[modules.StreamLargeBodies]
self.refresh_server_playback = False self.refresh_server_playback = False
@ -60,44 +57,6 @@ class FlowMaster(controller.Master):
level: debug, info, error level: debug, info, error
""" """
def unload_scripts(self):
for s in self.scripts[:]:
self.unload_script(s)
def unload_script(self, script_obj):
try:
script_obj.unload()
except script.ScriptException as e:
self.add_event("Script error:\n" + str(e), "error")
script.reloader.unwatch(script_obj)
self.scripts.remove(script_obj)
def load_script(self, command, use_reloader=False):
"""
Loads a script.
Raises:
ScriptException
"""
s = script.Script(command)
s.load()
if use_reloader:
s.reply = controller.DummyReply()
script.reloader.watch(s, lambda: self.event_queue.put(("script_change", s)))
self.scripts.append(s)
def _run_single_script_hook(self, script_obj, name, *args, **kwargs):
if script_obj and not self.pause_scripts:
try:
script_obj.run(name, *args, **kwargs)
except script.ScriptException as e:
self.add_event("Script error:\n{}".format(e), "error")
def run_scripts(self, name, msg):
for script_obj in self.scripts:
if not msg.reply.acked:
self._run_single_script_hook(script_obj, name, msg)
def get_ignore_filter(self): def get_ignore_filter(self):
return self.server.config.check_ignore.patterns return self.server.config.check_ignore.patterns
@ -298,11 +257,11 @@ class FlowMaster(controller.Master):
if not pb and self.kill_nonreplay: if not pb and self.kill_nonreplay:
f.kill(self) f.kill(self)
def replay_request(self, f, block=False, run_scripthooks=True): def replay_request(self, f, block=False):
""" """
Returns None if successful, or error message if not. Returns None if successful, or error message if not.
""" """
if f.live and run_scripthooks: if f.live:
return "Can't replay live request." return "Can't replay live request."
if f.intercepted: if f.intercepted:
return "Can't replay while intercepting..." return "Can't replay while intercepting..."
@ -319,7 +278,7 @@ class FlowMaster(controller.Master):
rt = http_replay.RequestReplayThread( rt = http_replay.RequestReplayThread(
self.server.config, self.server.config,
f, f,
self.event_queue if run_scripthooks else False, self.event_queue,
self.should_exit self.should_exit
) )
rt.start() # pragma: no cover rt.start() # pragma: no cover
@ -332,28 +291,27 @@ class FlowMaster(controller.Master):
@controller.handler @controller.handler
def clientconnect(self, root_layer): def clientconnect(self, root_layer):
self.run_scripts("clientconnect", root_layer) pass
@controller.handler @controller.handler
def clientdisconnect(self, root_layer): def clientdisconnect(self, root_layer):
self.run_scripts("clientdisconnect", root_layer) pass
@controller.handler @controller.handler
def serverconnect(self, server_conn): def serverconnect(self, server_conn):
self.run_scripts("serverconnect", server_conn) pass
@controller.handler @controller.handler
def serverdisconnect(self, server_conn): def serverdisconnect(self, server_conn):
self.run_scripts("serverdisconnect", server_conn) pass
@controller.handler @controller.handler
def next_layer(self, top_layer): def next_layer(self, top_layer):
self.run_scripts("next_layer", top_layer) pass
@controller.handler @controller.handler
def error(self, f): def error(self, f):
self.state.update_flow(f) self.state.update_flow(f)
self.run_scripts("error", f)
if self.client_playback: if self.client_playback:
self.client_playback.clear(f) self.client_playback.clear(f)
return f return f
@ -381,8 +339,6 @@ class FlowMaster(controller.Master):
self.setheaders.run(f) self.setheaders.run(f)
if not f.reply.acked: if not f.reply.acked:
self.process_new_request(f) self.process_new_request(f)
if not f.reply.acked:
self.run_scripts("request", f)
return f return f
@controller.handler @controller.handler
@ -393,7 +349,6 @@ class FlowMaster(controller.Master):
except netlib.exceptions.HttpException: except netlib.exceptions.HttpException:
f.reply.kill() f.reply.kill()
return return
self.run_scripts("responseheaders", f)
return f return f
@controller.handler @controller.handler
@ -404,7 +359,6 @@ class FlowMaster(controller.Master):
self.replacehooks.run(f) self.replacehooks.run(f)
if not f.reply.acked: if not f.reply.acked:
self.setheaders.run(f) self.setheaders.run(f)
self.run_scripts("response", f)
if not f.reply.acked: if not f.reply.acked:
if self.client_playback: if self.client_playback:
self.client_playback.clear(f) self.client_playback.clear(f)
@ -416,46 +370,15 @@ class FlowMaster(controller.Master):
def handle_accept_intercept(self, f): def handle_accept_intercept(self, f):
self.state.update_flow(f) self.state.update_flow(f)
@controller.handler
def script_change(self, s):
"""
Handle a script whose contents have been changed on the file system.
Args:
s (script.Script): the changed script
Returns:
True, if reloading was successful.
False, otherwise.
"""
ok = True
# We deliberately do not want to fail here.
# In the worst case, we have an "empty" script object.
try:
s.unload()
except script.ScriptException as e:
ok = False
self.add_event('Error reloading "{}":\n{}'.format(s.path, e), 'error')
try:
s.load()
except script.ScriptException as e:
ok = False
self.add_event('Error reloading "{}":\n{}'.format(s.path, e), 'error')
else:
self.add_event('"{}" reloaded.'.format(s.path), 'info')
return ok
@controller.handler @controller.handler
def tcp_open(self, flow): def tcp_open(self, flow):
# TODO: This would break mitmproxy currently. # TODO: This would break mitmproxy currently.
# self.state.add_flow(flow) # self.state.add_flow(flow)
self.active_flows.add(flow) self.active_flows.add(flow)
self.run_scripts("tcp_open", flow)
@controller.handler @controller.handler
def tcp_message(self, flow): def tcp_message(self, flow):
# type: (TCPFlow) -> None pass
self.run_scripts("tcp_message", flow)
@controller.handler @controller.handler
def tcp_error(self, flow): def tcp_error(self, flow):
@ -463,13 +386,7 @@ class FlowMaster(controller.Master):
repr(flow.server_conn.address), repr(flow.server_conn.address),
flow.error flow.error
), "info") ), "info")
self.run_scripts("tcp_error", flow)
@controller.handler @controller.handler
def tcp_close(self, flow): def tcp_close(self, flow):
self.active_flows.discard(flow) self.active_flows.discard(flow)
self.run_scripts("tcp_close", flow)
def shutdown(self):
super(FlowMaster, self).shutdown()
self.unload_scripts()

View File

@ -1,11 +1,5 @@
from . import reloader
from .concurrent import concurrent from .concurrent import concurrent
from .script import Script
from ..exceptions import ScriptException
__all__ = [ __all__ = [
"Script",
"concurrent", "concurrent",
"ScriptException",
"reloader"
] ]

View File

@ -13,7 +13,7 @@ class ScriptThread(basethread.BaseThread):
def concurrent(fn): def concurrent(fn):
if fn.__name__ not in controller.Events: if fn.__name__ not in controller.Events - set(["start", "configure", "tick"]):
raise NotImplementedError( raise NotImplementedError(
"Concurrent decorator not supported for '%s' method." % fn.__name__ "Concurrent decorator not supported for '%s' method." % fn.__name__
) )

View File

@ -1,47 +0,0 @@
from __future__ import absolute_import, print_function, division
import os
from watchdog.events import RegexMatchingEventHandler
from watchdog.observers.polling import PollingObserver as Observer
# We occasionally have watchdog errors on Windows, Linux and Mac when using the native observers.
# After reading through the watchdog source code and issue tracker,
# we may want to replace this with a very simple implementation of our own.
_observers = {}
def watch(script, callback):
if script in _observers:
raise RuntimeError("Script already observed")
script_dir = os.path.dirname(os.path.abspath(script.path))
script_name = os.path.basename(script.path)
event_handler = _ScriptModificationHandler(callback, filename=script_name)
observer = Observer()
observer.schedule(event_handler, script_dir)
observer.start()
_observers[script] = observer
def unwatch(script):
observer = _observers.pop(script, None)
if observer:
observer.stop()
observer.join()
class _ScriptModificationHandler(RegexMatchingEventHandler):
def __init__(self, callback, filename='.*'):
super(_ScriptModificationHandler, self).__init__(
ignore_directories=True,
regexes=['.*' + filename]
)
self.callback = callback
def on_modified(self, event):
self.callback()
__all__ = ["watch", "unwatch"]

View File

@ -1,136 +0,0 @@
"""
The script object representing mitmproxy inline scripts.
Script objects know nothing about mitmproxy or mitmproxy's API - this knowledge is provided
by the mitmproxy-specific ScriptContext.
"""
# Do not import __future__ here, this would apply transitively to the inline scripts.
from __future__ import absolute_import, print_function, division
import os
import shlex
import sys
import contextlib
import six
from typing import List # noqa
from mitmproxy import exceptions
@contextlib.contextmanager
def scriptenv(path, args):
# type: (str, List[str]) -> None
oldargs = sys.argv
script_dir = os.path.dirname(os.path.abspath(path))
sys.argv = [path] + args
sys.path.append(script_dir)
try:
yield
finally:
sys.argv = oldargs
sys.path.pop()
class Script(object):
"""
Script object representing an inline script.
"""
def __init__(self, command):
self.command = command
self.path, self.args = self.parse_command(command)
self.ns = None
def __enter__(self):
self.load()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val:
return False # re-raise the exception
self.unload()
@staticmethod
def parse_command(command):
# type: (str) -> Tuple[str,List[str]]
"""
Returns a (path, args) tuple.
"""
if not command or not command.strip():
raise exceptions.ScriptException("Empty script command.")
# Windows: escape all backslashes in the path.
if os.name == "nt": # pragma: no cover
backslashes = shlex.split(command, posix=False)[0].count("\\")
command = command.replace("\\", "\\\\", backslashes)
args = shlex.split(command) # pragma: no cover
args[0] = os.path.expanduser(args[0])
if not os.path.exists(args[0]):
raise exceptions.ScriptException(
("Script file not found: %s.\r\n"
"If your script path contains spaces, "
"make sure to wrap it in additional quotes, e.g. -s \"'./foo bar/baz.py' --args\".") %
args[0])
elif os.path.isdir(args[0]):
raise exceptions.ScriptException("Not a file: %s" % args[0])
return args[0], args[1:]
def load(self):
"""
Loads an inline script.
Returns:
The return value of self.run("start", ...)
Raises:
ScriptException on failure
"""
if self.ns is not None:
raise exceptions.ScriptException("Script is already loaded")
self.ns = {'__file__': os.path.abspath(self.path)}
with scriptenv(self.path, self.args):
try:
with open(self.path) as f:
code = compile(f.read(), self.path, 'exec')
exec(code, self.ns, self.ns)
except Exception:
six.reraise(
exceptions.ScriptException,
exceptions.ScriptException.from_exception_context(),
sys.exc_info()[2]
)
return self.run("start")
def unload(self):
try:
return self.run("done")
finally:
self.ns = None
def run(self, name, *args, **kwargs):
"""
Runs an inline script hook.
Returns:
The return value of the method.
None, if the script does not provide the method.
Raises:
ScriptException if there was an exception.
"""
if self.ns is None:
raise exceptions.ScriptException("Script not loaded.")
f = self.ns.get(name)
if f:
try:
with scriptenv(self.path, self.args):
return f(*args, **kwargs)
except Exception:
six.reraise(
exceptions.ScriptException,
exceptions.ScriptException.from_exception_context(),
sys.exc_info()[2]
)
else:
return None

View File

@ -56,6 +56,13 @@ class Data(object):
dirname = os.path.dirname(inspect.getsourcefile(m)) dirname = os.path.dirname(inspect.getsourcefile(m))
self.dirname = os.path.abspath(dirname) self.dirname = os.path.abspath(dirname)
def push(self, subpath):
"""
Change the data object to a path relative to the module.
"""
self.dirname = os.path.join(self.dirname, subpath)
return self
def path(self, path): def path(self, path):
""" """
Returns a path to the package data housed at 'path' under this Returns a path to the package data housed at 'path' under this

View File

@ -0,0 +1,187 @@
import time
from mitmproxy.builtins import script
from mitmproxy import exceptions
from mitmproxy.flow import master
from mitmproxy.flow import state
from mitmproxy.flow import options
from .. import tutils, mastertest
class TestParseCommand:
def test_empty_command(self):
with tutils.raises(exceptions.AddonError):
script.parse_command("")
with tutils.raises(exceptions.AddonError):
script.parse_command(" ")
def test_no_script_file(self):
with tutils.raises("not found"):
script.parse_command("notfound")
with tutils.tmpdir() as dir:
with tutils.raises("not a file"):
script.parse_command(dir)
def test_parse_args(self):
with tutils.chdir(tutils.test_data.dirname):
assert script.parse_command("data/scripts/a.py") == ("data/scripts/a.py", [])
assert script.parse_command("data/scripts/a.py foo bar") == ("data/scripts/a.py", ["foo", "bar"])
assert script.parse_command("data/scripts/a.py 'foo bar'") == ("data/scripts/a.py", ["foo bar"])
@tutils.skip_not_windows
def test_parse_windows(self):
with tutils.chdir(tutils.test_data.dirname):
assert script.parse_command("data\\scripts\\a.py") == ("data\\scripts\\a.py", [])
assert script.parse_command("data\\scripts\\a.py 'foo \\ bar'") == ("data\\scripts\\a.py", ['foo \\ bar'])
def test_load_script():
ns = script.load_script(
tutils.test_data.path(
"data/addonscripts/recorder.py"
), []
)
assert ns["configure"]
class TestScript(mastertest.MasterTest):
def test_simple(self):
s = state.State()
m = master.FlowMaster(options.Options(), None, s)
sc = script.Script(
tutils.test_data.path(
"data/addonscripts/recorder.py"
)
)
m.addons.add(sc)
assert sc.ns["call_log"] == [
("solo", "start", (), {}),
("solo", "configure", (options.Options(),), {})
]
sc.ns["call_log"] = []
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
recf = sc.ns["call_log"][0]
assert recf[1] == "request"
def test_reload(self):
s = state.State()
m = mastertest.RecordingMaster(options.Options(), None, s)
with tutils.tmpdir():
with open("foo.py", "w"):
pass
sc = script.Script("foo.py")
m.addons.add(sc)
for _ in range(100):
with open("foo.py", "a") as f:
f.write(".")
m.addons.invoke_with_context(sc, "tick")
time.sleep(0.1)
if m.event_log:
return
raise AssertionError("Change event not detected.")
def test_exception(self):
s = state.State()
m = mastertest.RecordingMaster(options.Options(), None, s)
sc = script.Script(
tutils.test_data.path("data/addonscripts/error.py")
)
m.addons.add(sc)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
assert m.event_log[0][0] == "error"
def test_duplicate_flow(self):
s = state.State()
fm = master.FlowMaster(None, None, s)
fm.addons.add(
script.Script(
tutils.test_data.path("data/addonscripts/duplicate_flow.py")
)
)
f = tutils.tflow()
fm.request(f)
assert fm.state.flow_count() == 2
assert not fm.state.view[0].request.is_replay
assert fm.state.view[1].request.is_replay
class TestScriptLoader(mastertest.MasterTest):
def test_simple(self):
s = state.State()
o = options.Options(scripts=[])
m = master.FlowMaster(o, None, s)
sc = script.ScriptLoader()
m.addons.add(sc)
assert len(m.addons) == 1
o.update(
scripts = [
tutils.test_data.path("data/addonscripts/recorder.py")
]
)
assert len(m.addons) == 2
o.update(scripts = [])
assert len(m.addons) == 1
def test_dupes(self):
s = state.State()
o = options.Options(scripts=["one", "one"])
m = master.FlowMaster(o, None, s)
sc = script.ScriptLoader()
tutils.raises(exceptions.OptionsError, m.addons.add, sc)
def test_order(self):
rec = tutils.test_data.path("data/addonscripts/recorder.py")
s = state.State()
o = options.Options(
scripts = [
"%s %s" % (rec, "a"),
"%s %s" % (rec, "b"),
"%s %s" % (rec, "c"),
]
)
m = mastertest.RecordingMaster(o, None, s)
sc = script.ScriptLoader()
m.addons.add(sc)
debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"]
assert debug == [
('debug', 'a start'), ('debug', 'a configure'),
('debug', 'b start'), ('debug', 'b configure'),
('debug', 'c start'), ('debug', 'c configure')
]
m.event_log[:] = []
o.scripts = [
"%s %s" % (rec, "c"),
"%s %s" % (rec, "a"),
"%s %s" % (rec, "b"),
]
debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"]
assert debug == [
('debug', 'c configure'),
('debug', 'a configure'),
('debug', 'b configure'),
]
m.event_log[:] = []
o.scripts = [
"%s %s" % (rec, "x"),
"%s %s" % (rec, "a"),
]
debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"]
assert debug == [
('debug', 'c done'),
('debug', 'b done'),
('debug', 'x start'),
('debug', 'x configure'),
('debug', 'a configure'),
]

View File

@ -1,7 +1,6 @@
import time import time
from mitmproxy.script import concurrent from mitmproxy.script import concurrent
@concurrent @concurrent
def request(flow): def request(flow):
time.sleep(0.1) time.sleep(0.1)

View File

@ -0,0 +1,6 @@
from mitmproxy import ctx
def request(flow):
f = ctx.master.duplicate_flow(flow)
ctx.master.replay_request(f, block=True)

View File

@ -0,0 +1,7 @@
def mkerr():
raise ValueError("Error!")
def request(flow):
mkerr()

View File

@ -0,0 +1,25 @@
from mitmproxy import controller
from mitmproxy import ctx
import sys
call_log = []
if len(sys.argv) > 1:
name = sys.argv[1]
else:
name = "solo"
# Keep a log of all possible event calls
evts = list(controller.Events) + ["configure"]
for i in evts:
def mkprox():
evt = i
def prox(*args, **kwargs):
lg = (name, evt, args, kwargs)
if evt != "log":
ctx.log.info(str(lg))
call_log.append(lg)
ctx.log.debug("%s %s" % (name, evt))
return prox
globals()[i] = mkprox()

View File

@ -0,0 +1,8 @@
def modify(chunks):
for chunk in chunks:
yield chunk.replace(b"foo", b"bar")
def responseheaders(flow):
flow.response.stream = modify

View File

@ -0,0 +1,5 @@
def tcp_message(flow):
message = flow.messages[-1]
if not message.from_client:
message.content = message.content.replace(b"foo", b"bar")

View File

@ -3,6 +3,7 @@ import mock
from . import tutils from . import tutils
import netlib.tutils import netlib.tutils
from mitmproxy.flow import master
from mitmproxy import flow, proxy, models, controller from mitmproxy import flow, proxy, models, controller
@ -39,3 +40,12 @@ class MasterTest:
t = tutils.tflow(resp=True) t = tutils.tflow(resp=True)
fw.add(t) fw.add(t)
f.close() f.close()
class RecordingMaster(master.FlowMaster):
def __init__(self, *args, **kwargs):
master.FlowMaster.__init__(self, *args, **kwargs)
self.event_log = []
def add_event(self, e, level):
self.event_log.append((level, e))

View File

@ -1,28 +1,46 @@
from mitmproxy.script import Script from test.mitmproxy import tutils, mastertest
from test.mitmproxy import tutils
from mitmproxy import controller from mitmproxy import controller
from mitmproxy.builtins import script
from mitmproxy import options
from mitmproxy.flow import master
from mitmproxy.flow import state
import time import time
class Thing: class Thing:
def __init__(self): def __init__(self):
self.reply = controller.DummyReply() self.reply = controller.DummyReply()
self.live = True
@tutils.skip_appveyor class TestConcurrent(mastertest.MasterTest):
def test_concurrent(): @tutils.skip_appveyor
with Script(tutils.test_data.path("data/scripts/concurrent_decorator.py")) as s: def test_concurrent(self):
f1, f2 = Thing(), Thing() s = state.State()
s.run("request", f1) m = master.FlowMaster(options.Options(), None, s)
s.run("request", f2) sc = script.Script(
tutils.test_data.path(
"data/addonscripts/concurrent_decorator.py"
)
)
m.addons.add(sc)
f1, f2 = tutils.tflow(), tutils.tflow()
self.invoke(m, "request", f1)
self.invoke(m, "request", f2)
start = time.time() start = time.time()
while time.time() - start < 5: while time.time() - start < 5:
if f1.reply.acked and f2.reply.acked: if f1.reply.acked and f2.reply.acked:
return return
raise ValueError("Script never acked") raise ValueError("Script never acked")
def test_concurrent_err(self):
def test_concurrent_err(): s = state.State()
s = Script(tutils.test_data.path("data/scripts/concurrent_decorator_err.py")) m = mastertest.RecordingMaster(options.Options(), None, s)
with tutils.raises("Concurrent decorator not supported for 'start' method"): sc = script.Script(
s.load() tutils.test_data.path(
"data/addonscripts/concurrent_decorator_err.py"
)
)
with m.handlecontext():
sc.start()
assert "decorator not supported" in m.event_log[0][1]

View File

@ -1,34 +0,0 @@
import mock
from mitmproxy.script.reloader import watch, unwatch
from test.mitmproxy import tutils
from threading import Event
def test_simple():
with tutils.tmpdir():
with open("foo.py", "w"):
pass
script = mock.Mock()
script.path = "foo.py"
e = Event()
def _onchange():
e.set()
watch(script, _onchange)
with tutils.raises("already observed"):
watch(script, _onchange)
# Some reloaders don't register a change directly after watching, because they first need to initialize.
# To test if watching works at all, we do repeated writes every 100ms.
for _ in range(100):
with open("foo.py", "a") as f:
f.write(".")
if e.wait(0.1):
break
else:
raise AssertionError("No change detected.")
unwatch(script)

View File

@ -1,83 +0,0 @@
from mitmproxy.script import Script
from mitmproxy.exceptions import ScriptException
from test.mitmproxy import tutils
class TestParseCommand:
def test_empty_command(self):
with tutils.raises(ScriptException):
Script.parse_command("")
with tutils.raises(ScriptException):
Script.parse_command(" ")
def test_no_script_file(self):
with tutils.raises("not found"):
Script.parse_command("notfound")
with tutils.tmpdir() as dir:
with tutils.raises("not a file"):
Script.parse_command(dir)
def test_parse_args(self):
with tutils.chdir(tutils.test_data.dirname):
assert Script.parse_command("data/scripts/a.py") == ("data/scripts/a.py", [])
assert Script.parse_command("data/scripts/a.py foo bar") == ("data/scripts/a.py", ["foo", "bar"])
assert Script.parse_command("data/scripts/a.py 'foo bar'") == ("data/scripts/a.py", ["foo bar"])
@tutils.skip_not_windows
def test_parse_windows(self):
with tutils.chdir(tutils.test_data.dirname):
assert Script.parse_command("data\\scripts\\a.py") == ("data\\scripts\\a.py", [])
assert Script.parse_command("data\\scripts\\a.py 'foo \\ bar'") == ("data\\scripts\\a.py", ['foo \\ bar'])
def test_simple():
with tutils.chdir(tutils.test_data.path("data/scripts")):
s = Script("a.py --var 42")
assert s.path == "a.py"
assert s.ns is None
s.load()
assert s.ns["var"] == 42
s.run("here")
assert s.ns["var"] == 43
s.unload()
assert s.ns is None
with tutils.raises(ScriptException):
s.run("here")
with Script("a.py --var 42") as s:
s.run("here")
def test_script_exception():
with tutils.chdir(tutils.test_data.path("data/scripts")):
s = Script("syntaxerr.py")
with tutils.raises(ScriptException):
s.load()
s = Script("starterr.py")
with tutils.raises(ScriptException):
s.load()
s = Script("a.py")
s.load()
with tutils.raises(ScriptException):
s.load()
s = Script("a.py")
with tutils.raises(ScriptException):
s.run("here")
with tutils.raises(ScriptException):
with Script("reqerr.py") as s:
s.run("request", None)
s = Script("unloaderr.py")
s.load()
with tutils.raises(ScriptException):
s.unload()

View File

@ -245,12 +245,12 @@ class TestDumpMaster(mastertest.MasterTest):
assert "XRESPONSE" in ret assert "XRESPONSE" in ret
assert "XCLIENTDISCONNECT" in ret assert "XCLIENTDISCONNECT" in ret
tutils.raises( tutils.raises(
dump.DumpError, exceptions.AddonError,
self.mkmaster, self.mkmaster,
None, scripts=["nonexistent"] None, scripts=["nonexistent"]
) )
tutils.raises( tutils.raises(
dump.DumpError, exceptions.AddonError,
self.mkmaster, self.mkmaster,
None, scripts=["starterr.py"] None, scripts=["starterr.py"]
) )

View File

@ -1,151 +1,126 @@
import glob
import json import json
import mock
import os
import sys
from contextlib import contextmanager
from mitmproxy import script import six
import sys
import os.path
from mitmproxy.flow import master
from mitmproxy.flow import state
from mitmproxy import options
from mitmproxy import contentviews
from mitmproxy.builtins import script
import netlib.utils import netlib.utils
from netlib import tutils as netutils from netlib import tutils as netutils
from netlib.http import Headers from netlib.http import Headers
from . import tutils from . import tutils, mastertest
example_dir = netlib.utils.Data(__name__).path("../../examples") example_dir = netlib.utils.Data(__name__).push("../../examples")
@contextmanager class ScriptError(Exception):
def example(command): pass
command = os.path.join(example_dir, command)
with script.Script(command) as s:
yield s
@mock.patch("mitmproxy.ctx.master") class RaiseMaster(master.FlowMaster):
@mock.patch("mitmproxy.ctx.log") def add_event(self, e, level):
def test_load_scripts(log, master): if level in ("warn", "error"):
scripts = glob.glob("%s/*.py" % example_dir) raise ScriptError(e)
for f in scripts:
if "har_extractor" in f:
continue
if "flowwriter" in f:
f += " -"
if "iframe_injector" in f:
f += " foo" # one argument required
if "filt" in f:
f += " ~a"
if "modify_response_body" in f:
f += " foo bar" # two arguments required
s = script.Script(f)
try:
s.load()
except Exception as v:
if "ImportError" not in str(v):
raise
else:
s.unload()
def test_add_header(): def tscript(cmd, args=""):
flow = tutils.tflow(resp=netutils.tresp()) cmd = example_dir.path(cmd) + " " + args
with example("add_header.py") as ex: m = RaiseMaster(options.Options(), None, state.State())
ex.run("response", flow) sc = script.Script(cmd)
assert flow.response.headers["newheader"] == "foo" m.addons.add(sc)
return m, sc
@mock.patch("mitmproxy.contentviews.remove") class TestScripts(mastertest.MasterTest):
@mock.patch("mitmproxy.contentviews.add") def test_add_header(self):
def test_custom_contentviews(add, remove): m, _ = tscript("add_header.py")
with example("custom_contentviews.py"): f = tutils.tflow(resp=netutils.tresp())
assert add.called self.invoke(m, "response", f)
pig = add.call_args[0][0] assert f.response.headers["newheader"] == "foo"
def test_custom_contentviews(self):
m, sc = tscript("custom_contentviews.py")
pig = contentviews.get("pig_latin_HTML")
_, fmt = pig(b"<html>test!</html>") _, fmt = pig(b"<html>test!</html>")
assert any(b'esttay!' in val[0][1] for val in fmt) assert any(b'esttay!' in val[0][1] for val in fmt)
assert not pig(b"gobbledygook") assert not pig(b"gobbledygook")
assert remove.called
def test_iframe_injector(self):
with tutils.raises(ScriptError):
tscript("iframe_injector.py")
def test_iframe_injector(): m, sc = tscript("iframe_injector.py", "http://example.org/evil_iframe")
with tutils.raises(script.ScriptException): flow = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>"))
with example("iframe_injector.py"): self.invoke(m, "response", flow)
pass
flow = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>"))
with example("iframe_injector.py http://example.org/evil_iframe") as ex:
ex.run("response", flow)
content = flow.response.content content = flow.response.content
assert b'iframe' in content and b'evil_iframe' in content assert b'iframe' in content and b'evil_iframe' in content
def test_modify_form(self):
m, sc = tscript("modify_form.py")
def test_modify_form(): form_header = Headers(content_type="application/x-www-form-urlencoded")
form_header = Headers(content_type="application/x-www-form-urlencoded") f = tutils.tflow(req=netutils.treq(headers=form_header))
flow = tutils.tflow(req=netutils.treq(headers=form_header)) self.invoke(m, "request", f)
with example("modify_form.py") as ex:
ex.run("request", flow)
assert flow.request.urlencoded_form[b"mitmproxy"] == b"rocks"
flow.request.headers["content-type"] = "" assert f.request.urlencoded_form[b"mitmproxy"] == b"rocks"
ex.run("request", flow)
assert list(flow.request.urlencoded_form.items()) == [(b"foo", b"bar")]
f.request.headers["content-type"] = ""
self.invoke(m, "request", f)
assert list(f.request.urlencoded_form.items()) == [(b"foo", b"bar")]
def test_modify_querystring(): def test_modify_querystring(self):
flow = tutils.tflow(req=netutils.treq(path=b"/search?q=term")) m, sc = tscript("modify_querystring.py")
with example("modify_querystring.py") as ex: f = tutils.tflow(req=netutils.treq(path="/search?q=term"))
ex.run("request", flow)
assert flow.request.query["mitmproxy"] == "rocks"
flow.request.path = "/" self.invoke(m, "request", f)
ex.run("request", flow) assert f.request.query["mitmproxy"] == "rocks"
assert flow.request.query["mitmproxy"] == "rocks"
f.request.path = "/"
self.invoke(m, "request", f)
assert f.request.query["mitmproxy"] == "rocks"
def test_modify_response_body(): def test_modify_response_body(self):
with tutils.raises(script.ScriptException): with tutils.raises(ScriptError):
with example("modify_response_body.py"): tscript("modify_response_body.py")
assert True
flow = tutils.tflow(resp=netutils.tresp(content=b"I <3 mitmproxy")) m, sc = tscript("modify_response_body.py", "mitmproxy rocks")
with example("modify_response_body.py mitmproxy rocks") as ex: f = tutils.tflow(resp=netutils.tresp(content=b"I <3 mitmproxy"))
assert ex.ns["state"]["old"] == b"mitmproxy" and ex.ns["state"]["new"] == b"rocks" self.invoke(m, "response", f)
ex.run("response", flow) assert f.response.content == b"I <3 rocks"
assert flow.response.content == b"I <3 rocks"
def test_redirect_requests(self):
m, sc = tscript("redirect_requests.py")
f = tutils.tflow(req=netutils.treq(host="example.org"))
self.invoke(m, "request", f)
assert f.request.host == "mitmproxy.org"
def test_redirect_requests(): def test_har_extractor(self):
flow = tutils.tflow(req=netutils.treq(host=b"example.org")) if sys.version_info >= (3, 0):
with example("redirect_requests.py") as ex: with tutils.raises("does not work on Python 3"):
ex.run("request", flow) tscript("har_extractor.py")
assert flow.request.host == "mitmproxy.org" return
with tutils.raises(ScriptError):
tscript("har_extractor.py")
@mock.patch("mitmproxy.ctx.log") with tutils.tmpdir() as tdir:
def test_har_extractor(log): times = dict(
if sys.version_info >= (3, 0): timestamp_start=746203272,
with tutils.raises("does not work on Python 3"): timestamp_end=746203272,
with example("har_extractor.py -"): )
pass
return
with tutils.raises(script.ScriptException): path = os.path.join(tdir, "file")
with example("har_extractor.py"): m, sc = tscript("har_extractor.py", six.moves.shlex_quote(path))
pass f = tutils.tflow(
req=netutils.treq(**times),
resp=netutils.tresp(**times)
)
self.invoke(m, "response", f)
m.addons.remove(sc)
times = dict( with open(path, "rb") as f:
timestamp_start=746203272, test_data = json.load(f)
timestamp_end=746203272, assert len(test_data["log"]["pages"]) == 1
)
flow = tutils.tflow(
req=netutils.treq(**times),
resp=netutils.tresp(**times)
)
with example("har_extractor.py -") as ex:
ex.run("response", flow)
with open(tutils.test_data.path("data/har_extractor.har")) as fp:
test_data = json.load(fp)
assert json.loads(ex.ns["context"].HARLog.json()) == test_data["test_response"]

View File

@ -5,7 +5,7 @@ import netlib.utils
from netlib.http import Headers from netlib.http import Headers
from mitmproxy import filt, controller, flow from mitmproxy import filt, controller, flow
from mitmproxy.contrib import tnetstring from mitmproxy.contrib import tnetstring
from mitmproxy.exceptions import FlowReadException, ScriptException from mitmproxy.exceptions import FlowReadException
from mitmproxy.models import Error from mitmproxy.models import Error
from mitmproxy.models import Flow from mitmproxy.models import Flow
from mitmproxy.models import HTTPFlow from mitmproxy.models import HTTPFlow
@ -674,21 +674,6 @@ class TestSerialize:
class TestFlowMaster: class TestFlowMaster:
def test_load_script(self):
s = flow.State()
fm = flow.FlowMaster(None, None, s)
fm.load_script(tutils.test_data.path("data/scripts/a.py"))
fm.load_script(tutils.test_data.path("data/scripts/a.py"))
fm.unload_scripts()
with tutils.raises(ScriptException):
fm.load_script("nonexistent")
try:
fm.load_script(tutils.test_data.path("data/scripts/starterr.py"))
except ScriptException as e:
assert "ValueError" in str(e)
assert len(fm.scripts) == 0
def test_getset_ignore(self): def test_getset_ignore(self):
p = mock.Mock() p = mock.Mock()
p.config.check_ignore = HostMatcher() p.config.check_ignore = HostMatcher()
@ -708,51 +693,7 @@ class TestFlowMaster:
assert "intercepting" in fm.replay_request(f) assert "intercepting" in fm.replay_request(f)
f.live = True f.live = True
assert "live" in fm.replay_request(f, run_scripthooks=True) assert "live" in fm.replay_request(f)
def test_script_reqerr(self):
s = flow.State()
fm = flow.FlowMaster(None, None, s)
fm.load_script(tutils.test_data.path("data/scripts/reqerr.py"))
f = tutils.tflow()
fm.clientconnect(f.client_conn)
assert fm.request(f)
def test_script(self):
s = flow.State()
fm = flow.FlowMaster(None, None, s)
fm.load_script(tutils.test_data.path("data/scripts/all.py"))
f = tutils.tflow(resp=True)
f.client_conn.acked = False
fm.clientconnect(f.client_conn)
assert fm.scripts[0].ns["log"][-1] == "clientconnect"
f.server_conn.acked = False
fm.serverconnect(f.server_conn)
assert fm.scripts[0].ns["log"][-1] == "serverconnect"
f.reply.acked = False
fm.request(f)
assert fm.scripts[0].ns["log"][-1] == "request"
f.reply.acked = False
fm.response(f)
assert fm.scripts[0].ns["log"][-1] == "response"
# load second script
fm.load_script(tutils.test_data.path("data/scripts/all.py"))
assert len(fm.scripts) == 2
f.server_conn.reply.acked = False
fm.clientdisconnect(f.server_conn)
assert fm.scripts[0].ns["log"][-1] == "clientdisconnect"
assert fm.scripts[1].ns["log"][-1] == "clientdisconnect"
# unload first script
fm.unload_scripts()
assert len(fm.scripts) == 0
fm.load_script(tutils.test_data.path("data/scripts/all.py"))
f.error = tutils.terr()
f.reply.acked = False
fm.error(f)
assert fm.scripts[0].ns["log"][-1] == "error"
def test_duplicate_flow(self): def test_duplicate_flow(self):
s = flow.State() s = flow.State()
@ -789,7 +730,6 @@ class TestFlowMaster:
f.error.reply = controller.DummyReply() f.error.reply = controller.DummyReply()
fm.error(f) fm.error(f)
fm.load_script(tutils.test_data.path("data/scripts/a.py"))
fm.shutdown() fm.shutdown()
def test_client_playback(self): def test_client_playback(self):

View File

@ -1,13 +0,0 @@
from mitmproxy import flow
from . import tutils
def test_duplicate_flow():
s = flow.State()
fm = flow.FlowMaster(None, None, s)
fm.load_script(tutils.test_data.path("data/scripts/duplicate_flow.py"))
f = tutils.tflow()
fm.request(f)
assert fm.state.flow_count() == 2
assert not fm.state.view[0].request.is_replay
assert fm.state.view[1].request.is_replay

View File

@ -13,6 +13,7 @@ from netlib.http import authentication, http1
from netlib.tutils import raises from netlib.tutils import raises
from pathod import pathoc, pathod from pathod import pathoc, pathod
from mitmproxy.builtins import script
from mitmproxy import controller from mitmproxy import controller
from mitmproxy.proxy.config import HostMatcher from mitmproxy.proxy.config import HostMatcher
from mitmproxy.models import Error, HTTPResponse, HTTPFlow from mitmproxy.models import Error, HTTPResponse, HTTPFlow
@ -287,10 +288,13 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):
self.master.set_stream_large_bodies(None) self.master.set_stream_large_bodies(None)
def test_stream_modify(self): def test_stream_modify(self):
self.master.load_script(tutils.test_data.path("data/scripts/stream_modify.py")) s = script.Script(
tutils.test_data.path("data/addonscripts/stream_modify.py")
)
self.master.addons.add(s)
d = self.pathod('200:b"foo"') d = self.pathod('200:b"foo"')
assert d.content == b"bar" assert d.content == b"bar"
self.master.unload_scripts() self.master.addons.remove(s)
class TestHTTPAuth(tservers.HTTPProxyTest): class TestHTTPAuth(tservers.HTTPProxyTest):
@ -512,15 +516,15 @@ class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin):
ssl = False ssl = False
def test_tcp_stream_modify(self): def test_tcp_stream_modify(self):
self.master.load_script(tutils.test_data.path("data/scripts/tcp_stream_modify.py")) s = script.Script(
tutils.test_data.path("data/addonscripts/tcp_stream_modify.py")
)
self.master.addons.add(s)
self._tcpproxy_on() self._tcpproxy_on()
d = self.pathod('200:b"foo"') d = self.pathod('200:b"foo"')
self._tcpproxy_off() self._tcpproxy_off()
assert d.content == b"bar" assert d.content == b"bar"
self.master.addons.remove(s)
self.master.unload_scripts()
class TestTransparentSSL(tservers.TransparentProxyTest, CommonMixin, TcpMixin): class TestTransparentSSL(tservers.TransparentProxyTest, CommonMixin, TcpMixin):