asyncio: shift script reloading out of the tick event

The tick event is a nasty compromise, left over from when we didn't have an
event loop. This is the first patch in a series that explores moving our
built-in addons to managing coroutines on the eventloop directly for periodic
tasks.
This commit is contained in:
Aldo Cortesi 2018-04-07 11:46:34 +12:00
parent f6b606b364
commit 44016a0de5
13 changed files with 133 additions and 105 deletions

View File

@ -140,6 +140,7 @@ class AddonManager:
""" """
for i in self.chain: for i in self.chain:
self.remove(i) self.remove(i)
self.lookup = {}
def get(self, name): def get(self, name):
""" """

View File

@ -1,7 +1,7 @@
import asyncio
import os import os
import importlib.util import importlib.util
import importlib.machinery import importlib.machinery
import time
import sys import sys
import types import types
import typing import typing
@ -59,13 +59,15 @@ def script_error_handler(path, exc, msg="", tb=False):
ctx.log.error(log_msg) ctx.log.error(log_msg)
ReloadInterval = 1
class Script: class Script:
""" """
An addon that manages a single script. An addon that manages a single script.
""" """
ReloadInterval = 2
def __init__(self, path): def __init__(self, path: str, reload: bool) -> None:
self.name = "scriptmanager:" + path self.name = "scriptmanager:" + path
self.path = path self.path = path
self.fullpath = os.path.expanduser( self.fullpath = os.path.expanduser(
@ -73,45 +75,57 @@ class Script:
) )
self.ns = None self.ns = None
self.last_load = 0
self.last_mtime = 0
if not os.path.isfile(self.fullpath): if not os.path.isfile(self.fullpath):
raise exceptions.OptionsError('No such script') raise exceptions.OptionsError('No such script')
self.reloadtask = None
if reload:
self.reloadtask = asyncio.ensure_future(self.watcher())
else:
self.loadscript()
def done(self):
if self.reloadtask:
self.reloadtask.cancel()
@property @property
def addons(self): def addons(self):
return [self.ns] if self.ns else [] return [self.ns] if self.ns else []
def tick(self): def loadscript(self):
if time.time() - self.last_load > self.ReloadInterval: ctx.log.info("Loading script %s" % self.path)
if self.ns:
ctx.master.addons.remove(self.ns)
self.ns = None
with addonmanager.safecall():
ns = load_script(self.fullpath)
ctx.master.addons.register(ns)
self.ns = ns
if self.ns:
# We're already running, so we have to explicitly register and
# configure the addon
ctx.master.addons.invoke_addon(self.ns, "running")
ctx.master.addons.invoke_addon(
self.ns,
"configure",
ctx.options.keys()
)
async def watcher(self):
last_mtime = 0
while True:
try: try:
mtime = os.stat(self.fullpath).st_mtime mtime = os.stat(self.fullpath).st_mtime
except FileNotFoundError: except FileNotFoundError:
ctx.log.info("Removing script %s" % self.path)
scripts = list(ctx.options.scripts) scripts = list(ctx.options.scripts)
scripts.remove(self.path) scripts.remove(self.path)
ctx.options.update(scripts=scripts) ctx.options.update(scripts=scripts)
return return
if mtime > last_mtime:
if mtime > self.last_mtime: self.loadscript()
ctx.log.info("Loading script: %s" % self.path) last_mtime = mtime
if self.ns: await asyncio.sleep(ReloadInterval)
ctx.master.addons.remove(self.ns)
self.ns = None
with addonmanager.safecall():
ns = load_script(self.fullpath)
ctx.master.addons.register(ns)
self.ns = ns
if self.ns:
# We're already running, so we have to explicitly register and
# configure the addon
ctx.master.addons.invoke_addon(self.ns, "running")
ctx.master.addons.invoke_addon(
self.ns,
"configure",
ctx.options.keys()
)
self.last_load = time.time()
self.last_mtime = mtime
class ScriptLoader: class ScriptLoader:
@ -125,9 +139,7 @@ class ScriptLoader:
def load(self, loader): def load(self, loader):
loader.add_option( loader.add_option(
"scripts", typing.Sequence[str], [], "scripts", typing.Sequence[str], [],
""" "Execute a script."
Execute a script.
"""
) )
def running(self): def running(self):
@ -141,12 +153,7 @@ class ScriptLoader:
simulated. simulated.
""" """
try: try:
s = Script(path) s = Script(path, False)
l = addonmanager.Loader(ctx.master)
ctx.master.addons.invoke_addon(s, "load", l)
ctx.master.addons.invoke_addon(s, "configure", ctx.options.keys())
# Script is loaded on the first tick
ctx.master.addons.invoke_addon(s, "tick")
for f in flows: for f in flows:
for evt, arg in eventsequence.iterate(f): for evt, arg in eventsequence.iterate(f):
ctx.master.addons.invoke_addon(s, evt, arg) ctx.master.addons.invoke_addon(s, evt, arg)
@ -161,7 +168,7 @@ class ScriptLoader:
for a in self.addons[:]: for a in self.addons[:]:
if a.path not in ctx.options.scripts: if a.path not in ctx.options.scripts:
ctx.log.info("Un-loading script: %s" % a.name) ctx.log.info("Un-loading script: %s" % a.path)
ctx.master.addons.remove(a) ctx.master.addons.remove(a)
self.addons.remove(a) self.addons.remove(a)
@ -181,7 +188,7 @@ class ScriptLoader:
if s in current: if s in current:
ordered.append(current[s]) ordered.append(current[s])
else: else:
sc = Script(s) sc = Script(s, True)
ordered.append(sc) ordered.append(sc)
newscripts.append(sc) newscripts.append(sc)

View File

@ -94,19 +94,14 @@ class Master:
exc = None exc = None
try: try:
loop() loop()
except Exception as e: except Exception as e: # pragma: no cover
exc = traceback.format_exc() exc = traceback.format_exc()
finally: finally:
if not self.should_exit.is_set(): if not self.should_exit.is_set(): # pragma: no cover
self.shutdown() self.shutdown()
pending = asyncio.Task.all_tasks()
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: for p in asyncio.Task.all_tasks():
loop.run_until_complete(asyncio.gather(*pending)) p.cancel()
except Exception as e:
# When we exit with an error, shutdown might not happen cleanly,
# and we can get exceptions here caused by pending Futures.
pass
loop.close() loop.close()
if exc: # pragma: no cover if exc: # pragma: no cover
@ -122,6 +117,7 @@ class Master:
self.run_loop(loop.run_forever) self.run_loop(loop.run_forever)
async def _shutdown(self): async def _shutdown(self):
self.should_exit.set()
if self.server: if self.server:
self.server.shutdown() self.server.shutdown()
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()

View File

@ -123,11 +123,7 @@ class context:
""" """
Loads a script from path, and returns the enclosed addon. Loads a script from path, and returns the enclosed addon.
""" """
sc = script.Script(path) sc = script.Script(path, False)
loader = addonmanager.Loader(self.master)
self.master.addons.invoke_addon(sc, "load", loader)
self.configure(sc)
self.master.addons.invoke_addon(sc, "tick")
return sc.addons[0] if sc.addons else None return sc.addons[0] if sc.addons else None
def invoke(self, addon, event, *args, **kwargs): def invoke(self, addon, event, *args, **kwargs):

View File

@ -4,23 +4,21 @@ from mitmproxy.addons import onboarding
from mitmproxy.test import taddons from mitmproxy.test import taddons
from .. import tservers from .. import tservers
import asyncio
import tornado.platform.asyncio
asyncio.set_event_loop_policy(tornado.platform.asyncio.AnyThreadEventLoopPolicy())
class TestApp(tservers.HTTPProxyTest): class TestApp(tservers.HTTPProxyTest):
def addons(self): def addons(self):
return [onboarding.Onboarding()] return [onboarding.Onboarding()]
def test_basic(self): @pytest.mark.asyncio
async def test_basic(self):
ob = onboarding.Onboarding() ob = onboarding.Onboarding()
with taddons.context(ob) as tctx: with taddons.context(ob) as tctx:
tctx.configure(ob) tctx.configure(ob)
assert self.app("/").status_code == 200 assert self.app("/").status_code == 200
@pytest.mark.parametrize("ext", ["pem", "p12"]) @pytest.mark.parametrize("ext", ["pem", "p12"])
def test_cert(self, ext): @pytest.mark.asyncio
async def test_cert(self, ext):
ob = onboarding.Onboarding() ob = onboarding.Onboarding()
with taddons.context(ob) as tctx: with taddons.context(ob) as tctx:
tctx.configure(ob) tctx.configure(ob)
@ -29,7 +27,8 @@ class TestApp(tservers.HTTPProxyTest):
assert resp.content assert resp.content
@pytest.mark.parametrize("ext", ["pem", "p12"]) @pytest.mark.parametrize("ext", ["pem", "p12"])
def test_head(self, ext): @pytest.mark.asyncio
async def test_head(self, ext):
ob = onboarding.Onboarding() ob = onboarding.Onboarding()
with taddons.context(ob) as tctx: with taddons.context(ob) as tctx:
tctx.configure(ob) tctx.configure(ob)

View File

@ -12,6 +12,10 @@ from mitmproxy.test import tflow
from mitmproxy.test import tutils from mitmproxy.test import tutils
# We want this to be speedy for testing
script.ReloadInterval = 0.1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_script(): async def test_load_script():
with taddons.context() as tctx: with taddons.context() as tctx:
@ -71,7 +75,7 @@ class TestScript:
def test_notfound(self): def test_notfound(self):
with taddons.context(): with taddons.context():
with pytest.raises(exceptions.OptionsError): with pytest.raises(exceptions.OptionsError):
script.Script("nonexistent") script.Script("nonexistent", False)
def test_quotes_around_filename(self): def test_quotes_around_filename(self):
""" """
@ -81,21 +85,23 @@ class TestScript:
path = tutils.test_data.path("mitmproxy/data/addonscripts/recorder/recorder.py") path = tutils.test_data.path("mitmproxy/data/addonscripts/recorder/recorder.py")
s = script.Script( s = script.Script(
'"{}"'.format(path) '"{}"'.format(path),
False
) )
assert '"' not in s.fullpath assert '"' not in s.fullpath
def test_simple(self): @pytest.mark.asyncio
async def test_simple(self):
with taddons.context() as tctx: with taddons.context() as tctx:
sc = script.Script( sc = script.Script(
tutils.test_data.path( tutils.test_data.path(
"mitmproxy/data/addonscripts/recorder/recorder.py" "mitmproxy/data/addonscripts/recorder/recorder.py"
) ),
True,
) )
tctx.master.addons.add(sc) tctx.master.addons.add(sc)
tctx.configure(sc) tctx.configure(sc)
sc.tick() await tctx.master.await_log("recorder running")
rec = tctx.master.addons.get("recorder") rec = tctx.master.addons.get("recorder")
assert rec.call_log[0][0:2] == ("recorder", "load") assert rec.call_log[0][0:2] == ("recorder", "load")
@ -112,25 +118,24 @@ class TestScript:
f = tmpdir.join("foo.py") f = tmpdir.join("foo.py")
f.ensure(file=True) f.ensure(file=True)
f.write("\n") f.write("\n")
sc = script.Script(str(f)) sc = script.Script(str(f), True)
tctx.configure(sc) tctx.configure(sc)
sc.tick()
assert await tctx.master.await_log("Loading") assert await tctx.master.await_log("Loading")
tctx.master.clear()
sc.last_load, sc.last_mtime = 0, 0 tctx.master.clear()
sc.tick() f.write("\n")
assert await tctx.master.await_log("Loading") assert await tctx.master.await_log("Loading")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exception(self): async def test_exception(self):
with taddons.context() as tctx: with taddons.context() as tctx:
sc = script.Script( sc = script.Script(
tutils.test_data.path("mitmproxy/data/addonscripts/error.py") tutils.test_data.path("mitmproxy/data/addonscripts/error.py"),
True,
) )
tctx.master.addons.add(sc) tctx.master.addons.add(sc)
await tctx.master.await_log("error running")
tctx.configure(sc) tctx.configure(sc)
sc.tick()
f = tflow.tflow(resp=True) f = tflow.tflow(resp=True)
tctx.master.addons.trigger("request", f) tctx.master.addons.trigger("request", f)
@ -138,16 +143,17 @@ class TestScript:
assert await tctx.master.await_log("ValueError: Error!") assert await tctx.master.await_log("ValueError: Error!")
assert await tctx.master.await_log("error.py") assert await tctx.master.await_log("error.py")
def test_addon(self): @pytest.mark.asyncio
async def test_addon(self):
with taddons.context() as tctx: with taddons.context() as tctx:
sc = script.Script( sc = script.Script(
tutils.test_data.path( tutils.test_data.path(
"mitmproxy/data/addonscripts/addon.py" "mitmproxy/data/addonscripts/addon.py"
) ),
True
) )
tctx.master.addons.add(sc) tctx.master.addons.add(sc)
tctx.configure(sc) await tctx.master.await_log("addon running")
sc.tick()
assert sc.ns.event_log == [ assert sc.ns.event_log == [
'scriptload', 'addonload', 'scriptconfigure', 'addonconfigure' 'scriptload', 'addonload', 'scriptconfigure', 'addonconfigure'
] ]
@ -184,7 +190,6 @@ class TestScriptLoader:
debug = [i.msg for i in tctx.master.logs if i.level == "debug"] debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [ assert debug == [
'recorder load', 'recorder running', 'recorder configure', 'recorder load', 'recorder running', 'recorder configure',
'recorder tick',
'recorder requestheaders', 'recorder request', 'recorder requestheaders', 'recorder request',
'recorder responseheaders', 'recorder response' 'recorder responseheaders', 'recorder response'
] ]
@ -224,17 +229,21 @@ class TestScriptLoader:
scripts = ["one", "one"] scripts = ["one", "one"]
) )
def test_script_deletion(self): @pytest.mark.asyncio
async def test_script_deletion(self):
tdir = tutils.test_data.path("mitmproxy/data/addonscripts/") tdir = tutils.test_data.path("mitmproxy/data/addonscripts/")
with open(tdir + "/dummy.py", 'w') as f: with open(tdir + "/dummy.py", 'w') as f:
f.write("\n") f.write("\n")
with taddons.context() as tctx: with taddons.context() as tctx:
sl = script.ScriptLoader() sl = script.ScriptLoader()
tctx.master.addons.add(sl) tctx.master.addons.add(sl)
tctx.configure(sl, scripts=[tutils.test_data.path("mitmproxy/data/addonscripts/dummy.py")]) tctx.configure(sl, scripts=[tutils.test_data.path("mitmproxy/data/addonscripts/dummy.py")])
await tctx.master.await_log("Loading")
os.remove(tutils.test_data.path("mitmproxy/data/addonscripts/dummy.py")) os.remove(tutils.test_data.path("mitmproxy/data/addonscripts/dummy.py"))
tctx.invoke(sl, "tick")
await tctx.master.await_log("Removing")
assert not tctx.options.scripts assert not tctx.options.scripts
assert not sl.addons assert not sl.addons
@ -286,17 +295,14 @@ class TestScriptLoader:
'a load', 'a load',
'a running', 'a running',
'a configure', 'a configure',
'a tick',
'b load', 'b load',
'b running', 'b running',
'b configure', 'b configure',
'b tick',
'c load', 'c load',
'c running', 'c running',
'c configure', 'c configure',
'c tick',
] ]
tctx.master.clear() tctx.master.clear()
@ -317,7 +323,7 @@ class TestScriptLoader:
'b configure', 'b configure',
] ]
tctx.master.logs = [] tctx.master.clear()
tctx.configure( tctx.configure(
sc, sc,
scripts = [ scripts = [
@ -325,9 +331,7 @@ class TestScriptLoader:
"%s/a.py" % rec, "%s/a.py" % rec,
] ]
) )
tctx.master.addons.invoke_addon(sc, "tick") await tctx.master.await_log("Loading")
await tctx.master.await_log("a tick")
debug = [i.msg for i in tctx.master.logs if i.level == "debug"] debug = [i.msg for i in tctx.master.logs if i.level == "debug"]
assert debug == [ assert debug == [
'c done', 'c done',
@ -336,6 +340,4 @@ class TestScriptLoader:
'e load', 'e load',
'e running', 'e running',
'e configure', 'e configure',
'e tick',
'a tick',
] ]

View File

@ -1,3 +1,4 @@
from mitmproxy import ctx
event_log = [] event_log = []
@ -7,6 +8,7 @@ class Addon:
return event_log return event_log
def load(self, opts): def load(self, opts):
ctx.log.info("addon running")
event_log.append("addonload") event_log.append("addonload")
def configure(self, updated): def configure(self, updated):

View File

@ -1,6 +1,9 @@
def mkerr(): from mitmproxy import ctx
raise ValueError("Error!")
def running():
ctx.log.info("error running")
def request(flow): def request(flow):
mkerr() raise ValueError("Error!")

View File

@ -0,0 +1,5 @@
from mitmproxy import ctx
def running():
ctx.master.shutdown()

View File

@ -1,7 +1,13 @@
from mitmproxy import ctx
def modify(chunks): def modify(chunks):
for chunk in chunks: for chunk in chunks:
yield chunk.replace(b"foo", b"bar") yield chunk.replace(b"foo", b"bar")
def running():
ctx.log.info("stream_modify running")
def responseheaders(flow): def responseheaders(flow):
flow.response.stream = modify flow.response.stream = modify

View File

@ -256,11 +256,15 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin):
resp = p.request("get:'http://foo':h':foo'='bar'") resp = p.request("get:'http://foo':h':foo'='bar'")
assert resp.status_code == 400 assert resp.status_code == 400
def test_stream_modify(self): @pytest.mark.asyncio
async def test_stream_modify(self):
s = script.Script( s = script.Script(
tutils.test_data.path("mitmproxy/data/addonscripts/stream_modify.py") tutils.test_data.path("mitmproxy/data/addonscripts/stream_modify.py"),
False,
) )
self.set_addons(s) self.set_addons(s)
await self.master.await_log("stream_modify running")
d = self.pathod('200:b"foo"') d = self.pathod('200:b"foo"')
assert d.content == b"bar" assert d.content == b"bar"
@ -564,7 +568,8 @@ class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin):
def test_tcp_stream_modify(self): def test_tcp_stream_modify(self):
s = script.Script( s = script.Script(
tutils.test_data.path("mitmproxy/data/addonscripts/tcp_stream_modify.py") tutils.test_data.path("mitmproxy/data/addonscripts/tcp_stream_modify.py"),
False,
) )
self.set_addons(s) self.set_addons(s)
self._tcpproxy_on() self._tcpproxy_on()

View File

@ -1,19 +1,23 @@
import pytest import asyncio
from mitmproxy.tools import main from mitmproxy.tools import main
from mitmproxy import ctx from mitmproxy.test import tutils
shutdown_script = tutils.test_data.path("mitmproxy/data/addonscripts/shutdown.py")
@pytest.mark.asyncio def test_mitmweb(event_loop):
async def test_mitmweb(event_loop): asyncio.set_event_loop(event_loop)
main.mitmweb([ main.mitmweb([
"--no-web-open-browser", "--no-web-open-browser",
"-s", shutdown_script,
"-q", "-p", "0", "-q", "-p", "0",
]) ])
await ctx.master._shutdown()
@pytest.mark.asyncio def test_mitmdump(event_loop):
async def test_mitmdump(): asyncio.set_event_loop(event_loop)
main.mitmdump(["-q", "-p", "0"]) main.mitmdump([
await ctx.master._shutdown() "-s", shutdown_script,
"-q", "-p", "0",
])

View File

@ -1,5 +1,6 @@
import json import json
from unittest import mock from unittest import mock
import pytest
from mitmproxy.test import taddons from mitmproxy.test import taddons
from mitmproxy.test import tflow from mitmproxy.test import tflow
@ -57,7 +58,8 @@ def test_save_flows_content(ctx, tmpdir):
assert p.join('response/content/Auto.json').check(file=1) assert p.join('response/content/Auto.json').check(file=1)
def test_static_viewer(tmpdir): @pytest.mark.asyncio
async def test_static_viewer(tmpdir):
s = static_viewer.StaticViewer() s = static_viewer.StaticViewer()
rf = readfile.ReadFile() rf = readfile.ReadFile()
sa = save.Save() sa = save.Save()