improve mitmproxy.scripts semantics, clean up tests

This commit is contained in:
Maximilian Hils 2016-03-18 18:51:54 +01:00
parent 36fb8a32f4
commit 898f5d10b9
8 changed files with 119 additions and 72 deletions

View File

@ -318,6 +318,7 @@ class ConsoleMaster(flow.FlowMaster):
try: try:
s = script.Script(command, script.ScriptContext(self)) s = script.Script(command, script.ScriptContext(self))
s.load()
except script.ScriptException as v: except script.ScriptException as v:
signals.status_message.send( signals.status_message.send(
message = "Error loading script." message = "Error loading script."

View File

@ -685,6 +685,7 @@ class FlowMaster(controller.Master):
""" """
try: try:
s = script.Script(command, script.ScriptContext(self)) s = script.Script(command, script.ScriptContext(self))
s.load()
except script.ScriptException as v: except script.ScriptException as v:
return v.args[0] return v.args[0]
if use_reloader: if use_reloader:

View File

@ -22,7 +22,15 @@ class Script(object):
self.args = self.parse_command(command) self.args = self.parse_command(command)
self.ctx = context self.ctx = context
self.ns = None self.ns = None
def __enter__(self):
self.load() self.load()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val:
return False # reraise the exception
self.unload()
@property @property
def filename(self): def filename(self):
@ -35,7 +43,7 @@ class Script(object):
if os.name == "nt": # Windows: escape all backslashes in the path. if os.name == "nt": # Windows: escape all backslashes in the path.
backslashes = shlex.split(command, posix=False)[0].count("\\") backslashes = shlex.split(command, posix=False)[0].count("\\")
command = command.replace("\\", "\\\\", backslashes) command = command.replace("\\", "\\\\", backslashes)
args = shlex.split(command) args = shlex.split(command) # pragma: nocover
args[0] = os.path.expanduser(args[0]) args[0] = os.path.expanduser(args[0])
if not os.path.exists(args[0]): if not os.path.exists(args[0]):
raise ScriptException( raise ScriptException(
@ -58,7 +66,7 @@ class Script(object):
ScriptException on failure ScriptException on failure
""" """
if self.ns is not None: if self.ns is not None:
self.unload() raise ScriptException("Script is already loaded")
script_dir = os.path.dirname(os.path.abspath(self.args[0])) script_dir = os.path.dirname(os.path.abspath(self.args[0]))
self.ns = {'__file__': os.path.abspath(self.args[0])} self.ns = {'__file__': os.path.abspath(self.args[0])}
sys.path.append(script_dir) sys.path.append(script_dir)

View File

View File

@ -0,0 +1,83 @@
from mitmproxy.script import Script
from mitmproxy.exceptions import ScriptException
from test.mitmproxy import tutils
class TestParseCommand:
def test_empty_command(self):
with tutils.raises(ScriptException):
Script.parse_command("")
with tutils.raises(ScriptException):
Script.parse_command(" ")
def test_no_script_file(self):
with tutils.raises("not found"):
Script.parse_command("notfound")
with tutils.tmpdir() as dir:
with tutils.raises("not a file"):
Script.parse_command(dir)
def test_parse_args(self):
with tutils.chdir(tutils.test_data.dirname):
assert Script.parse_command("scripts/a.py") == ["scripts/a.py"]
assert Script.parse_command("scripts/a.py foo bar") == ["scripts/a.py", "foo", "bar"]
assert Script.parse_command("scripts/a.py 'foo bar'") == ["scripts/a.py", "foo bar"]
@tutils.skip_not_windows
def test_parse_windows(self):
with tutils.chdir(tutils.test_data.dirname):
assert Script.parse_command("scripts\\a.py") == ["scripts\\a.py"]
assert Script.parse_command("scripts\\a.py 'foo \\ bar'") == ["scripts\\a.py", 'foo \\ bar']
def test_simple():
with tutils.chdir(tutils.test_data.path("scripts")):
s = Script("a.py --var 42", None)
assert s.filename == "a.py"
assert s.ns is None
s.load()
assert s.ns["var"] == 42
s.run("here")
assert s.ns["var"] == 43
s.unload()
assert s.ns is None
with tutils.raises(ScriptException):
s.run("here")
with Script("a.py --var 42", None) as s:
s.run("here")
def test_script_exception():
with tutils.chdir(tutils.test_data.path("scripts")):
s = Script("syntaxerr.py", None)
with tutils.raises(ScriptException):
s.load()
s = Script("starterr.py", None)
with tutils.raises(ScriptException):
s.load()
s = Script("a.py", None)
s.load()
with tutils.raises(ScriptException):
s.load()
s = Script("a.py", None)
with tutils.raises(ScriptException):
s.run("here")
with tutils.raises(ScriptException):
with Script("reqerr.py", None) as s:
s.run("request", None)
s = Script("unloaderr.py", None)
s.load()
with tutils.raises(ScriptException):
s.unload()

View File

@ -31,9 +31,8 @@ class DummyContext(object):
def example(command): def example(command):
command = os.path.join(example_dir, command) command = os.path.join(example_dir, command)
ctx = DummyContext() ctx = DummyContext()
s = script.Script(command, ctx) with script.Script(command, ctx) as s:
yield s yield s
s.unload()
def test_load_scripts(): def test_load_scripts():
@ -52,8 +51,10 @@ def test_load_scripts():
f += " ~a" f += " ~a"
if "modify_response_body" in f: if "modify_response_body" in f:
f += " foo bar" # two arguments required f += " foo bar" # two arguments required
s = script.Script(f, script.ScriptContext(tmaster))
try: try:
s = script.Script(f, script.ScriptContext(tmaster)) # Loads the script file. s.load()
except Exception as v: except Exception as v:
if "ImportError" not in str(v): if "ImportError" not in str(v):
raise raise

View File

@ -1,27 +1,9 @@
import os
import time import time
import mock import mock
from mitmproxy import script, flow from mitmproxy import script, flow
from . import tutils from . import tutils
def test_simple():
s = flow.State()
fm = flow.FlowMaster(None, s)
sp = tutils.test_data.path("scripts/a.py")
p = script.Script("%s --var 40" % sp, script.ScriptContext(fm))
assert "here" in p.ns
assert p.run("here") == 41
assert p.run("here") == 42
tutils.raises(script.ScriptException, p.run, "errargs")
# Check reload
p.load()
assert p.run("here") == 41
def test_duplicate_flow(): def test_duplicate_flow():
s = flow.State() s = flow.State()
fm = flow.FlowMaster(None, s) fm = flow.FlowMaster(None, s)
@ -33,35 +15,6 @@ def test_duplicate_flow():
assert fm.state.view[1].request.is_replay assert fm.state.view[1].request.is_replay
def test_err():
s = flow.State()
fm = flow.FlowMaster(None, s)
sc = script.ScriptContext(fm)
tutils.raises(
"not found",
script.Script, "nonexistent", sc
)
tutils.raises(
"not a file",
script.Script, tutils.test_data.path("scripts"), sc
)
tutils.raises(
script.ScriptException,
script.Script, tutils.test_data.path("scripts/syntaxerr.py"), sc
)
tutils.raises(
script.ScriptException,
script.Script, tutils.test_data.path("scripts/loaderr.py"), sc
)
scr = script.Script(tutils.test_data.path("scripts/unloaderr.py"), sc)
tutils.raises(script.ScriptException, scr.unload)
@tutils.skip_appveyor @tutils.skip_appveyor
def test_concurrent(): def test_concurrent():
s = flow.State() s = flow.State()
@ -117,16 +70,6 @@ def test_concurrent2():
def test_concurrent_err(): def test_concurrent_err():
s = flow.State() s = flow.State()
fm = flow.FlowMaster(None, s) fm = flow.FlowMaster(None, s)
tutils.raises( with tutils.raises("Concurrent decorator not supported for 'start' method"):
"Concurrent decorator not supported for 'start' method", s = script.Script(tutils.test_data.path("scripts/concurrent_decorator_err.py"), fm)
script.Script, s.load()
tutils.test_data.path("scripts/concurrent_decorator_err.py"),
fm)
def test_command_parsing():
s = flow.State()
fm = flow.FlowMaster(None, s)
absfilepath = os.path.normcase(tutils.test_data.path("scripts/a.py"))
s = script.Script(absfilepath, script.ScriptContext(fm))
assert os.path.isfile(s.args[0])

View File

@ -26,6 +26,13 @@ def skip_windows(fn):
return fn return fn
def skip_not_windows(fn):
if os.name == "nt":
return fn
else:
return _skip_windows
def _skip_appveyor(*args): def _skip_appveyor(*args):
raise SkipTest("Skipped on AppVeyor.") raise SkipTest("Skipped on AppVeyor.")
@ -119,15 +126,18 @@ def get_body_line(last_displayed_body, line_nb):
return last_displayed_body.contents()[line_nb + 2] return last_displayed_body.contents()[line_nb + 2]
@contextmanager
def chdir(dir):
orig_dir = os.getcwd()
os.chdir(dir)
yield
os.chdir(orig_dir)
@contextmanager @contextmanager
def tmpdir(*args, **kwargs): def tmpdir(*args, **kwargs):
orig_workdir = os.getcwd()
temp_workdir = tempfile.mkdtemp(*args, **kwargs) temp_workdir = tempfile.mkdtemp(*args, **kwargs)
os.chdir(temp_workdir) with chdir(temp_workdir):
yield temp_workdir yield temp_workdir
os.chdir(orig_workdir)
shutil.rmtree(temp_workdir) shutil.rmtree(temp_workdir)