Script cleanups

- Preserve script order on config change
- Prohibit script duplicates (i.e. identical script + args)
- Various cleanups and tweaks
This commit is contained in:
Aldo Cortesi 2016-07-15 16:35:24 +12:00
parent 917d51bd22
commit c7d0850d8f
3 changed files with 100 additions and 18 deletions

View File

@ -99,23 +99,25 @@ class Script:
self.path, self.args = parse_command(command) self.path, self.args = parse_command(command)
self.ns = None self.ns = None
self.observer = None self.observer = None
self.dead = False
self.last_options = None self.last_options = None
self.should_reload = threading.Event() self.should_reload = threading.Event()
for i in controller.Events - set(["start", "configure", "tick"]): for i in controller.Events:
def mkprox(): if not hasattr(self, i):
evt = i def mkprox():
evt = i
def prox(*args, **kwargs): def prox(*args, **kwargs):
self.run(evt, *args, **kwargs) self.run(evt, *args, **kwargs)
return prox return prox
setattr(self, i, mkprox()) setattr(self, i, mkprox())
def run(self, name, *args, **kwargs): def run(self, name, *args, **kwargs):
# It's possible for ns to be un-initialised if we failed during # It's possible for ns to be un-initialised if we failed during
# configure # configure
if self.ns is not None: if self.ns is not None and not self.dead:
func = self.ns.get(name) func = self.ns.get(name)
if func: if func:
with scriptenv(self.path, self.args): with scriptenv(self.path, self.args):
@ -149,18 +151,35 @@ class Script:
self.observer.start() self.observer.start()
self.run("configure", options) self.run("configure", options)
def done(self):
self.run("done")
self.dead = True
class ScriptLoader(): class ScriptLoader():
""" """
An addon that manages loading scripts from options. An addon that manages loading scripts from options.
""" """
def configure(self, options): def configure(self, options):
for s in options.scripts or []: for s in options.scripts:
if not ctx.master.addons.has_addon(s): if options.scripts.count(s) > 1:
raise exceptions.OptionsError("Duplicate script: %s" % s)
for a in ctx.master.addons.chain[:]:
if isinstance(a, Script) and a.name not in options.scripts:
ctx.log.info("Un-loading script: %s" % a.name)
ctx.master.addons.remove(a)
current = {}
for a in ctx.master.addons.chain[:]:
if isinstance(a, Script):
current[a.name] = a
ctx.master.addons.chain.remove(a)
for s in options.scripts:
if s in current:
ctx.master.addons.chain.append(current[s])
else:
ctx.log.info("Loading script: %s" % s) ctx.log.info("Loading script: %s" % s)
sc = Script(s) sc = Script(s)
ctx.master.addons.add(sc) ctx.master.addons.add(sc)
for a in ctx.master.addons.chain:
if isinstance(a, Script):
if a.name not in options.scripts or []:
ctx.master.addons.remove(a)

View File

@ -58,8 +58,8 @@ class TestScript(mastertest.MasterTest):
) )
m.addons.add(sc) m.addons.add(sc)
assert sc.ns["call_log"] == [ assert sc.ns["call_log"] == [
("start", (), {}), ("solo", "start", (), {}),
("configure", (options.Options(),), {}) ("solo", "configure", (options.Options(),), {})
] ]
sc.ns["call_log"] = [] sc.ns["call_log"] = []
@ -67,7 +67,7 @@ class TestScript(mastertest.MasterTest):
self.invoke(m, "request", f) self.invoke(m, "request", f)
recf = sc.ns["call_log"][0] recf = sc.ns["call_log"][0]
assert recf[0] == "request" assert recf[1] == "request"
def test_reload(self): def test_reload(self):
s = state.State() s = state.State()
@ -129,3 +129,59 @@ class TestScriptLoader(mastertest.MasterTest):
assert len(m.addons) == 2 assert len(m.addons) == 2
o.update(scripts = []) o.update(scripts = [])
assert len(m.addons) == 1 assert len(m.addons) == 1
def test_dupes(self):
s = state.State()
o = options.Options(scripts=["one", "one"])
m = master.FlowMaster(o, None, s)
sc = script.ScriptLoader()
tutils.raises(exceptions.OptionsError, m.addons.add, sc)
def test_order(self):
rec = tutils.test_data.path("data/addonscripts/recorder.py")
s = state.State()
o = options.Options(
scripts = [
"%s %s" % (rec, "a"),
"%s %s" % (rec, "b"),
"%s %s" % (rec, "c"),
]
)
m = mastertest.RecordingMaster(o, None, s)
sc = script.ScriptLoader()
m.addons.add(sc)
debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"]
assert debug == [
('debug', 'a start'), ('debug', 'a configure'),
('debug', 'b start'), ('debug', 'b configure'),
('debug', 'c start'), ('debug', 'c configure')
]
m.event_log[:] = []
o.scripts = [
"%s %s" % (rec, "c"),
"%s %s" % (rec, "a"),
"%s %s" % (rec, "b"),
]
debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"]
assert debug == [
('debug', 'c configure'),
('debug', 'a configure'),
('debug', 'b configure'),
]
m.event_log[:] = []
o.scripts = [
"%s %s" % (rec, "x"),
"%s %s" % (rec, "a"),
]
debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"]
assert debug == [
('debug', 'c done'),
('debug', 'b done'),
('debug', 'x start'),
('debug', 'x configure'),
('debug', 'a configure'),
]

View File

@ -1,8 +1,14 @@
from mitmproxy import controller from mitmproxy import controller
from mitmproxy import ctx from mitmproxy import ctx
import sys
call_log = [] call_log = []
if len(sys.argv) > 1:
name = sys.argv[1]
else:
name = "solo"
# Keep a log of all possible event calls # Keep a log of all possible event calls
evts = list(controller.Events) + ["configure"] evts = list(controller.Events) + ["configure"]
for i in evts: for i in evts:
@ -10,9 +16,10 @@ for i in evts:
evt = i evt = i
def prox(*args, **kwargs): def prox(*args, **kwargs):
lg = (evt, args, kwargs) lg = (name, evt, args, kwargs)
if evt != "log": if evt != "log":
ctx.log.info(str(lg)) ctx.log.info(str(lg))
call_log.append(lg) call_log.append(lg)
ctx.log.debug("%s %s" % (name, evt))
return prox return prox
globals()[i] = mkprox() globals()[i] = mkprox()