diff --git a/mitmproxy/console/__init__.py b/mitmproxy/console/__init__.py index 5669be48a..2f52d0b89 100644 --- a/mitmproxy/console/__init__.py +++ b/mitmproxy/console/__init__.py @@ -318,6 +318,7 @@ class ConsoleMaster(flow.FlowMaster): try: s = script.Script(command, script.ScriptContext(self)) + s.load() except script.ScriptException as v: signals.status_message.send( message = "Error loading script." diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 8fa84ed8e..7de9823da 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -685,6 +685,7 @@ class FlowMaster(controller.Master): """ try: s = script.Script(command, script.ScriptContext(self)) + s.load() except script.ScriptException as v: return v.args[0] if use_reloader: diff --git a/mitmproxy/script/script.py b/mitmproxy/script/script.py index 55778851b..edc17d43a 100644 --- a/mitmproxy/script/script.py +++ b/mitmproxy/script/script.py @@ -22,7 +22,15 @@ class Script(object): self.args = self.parse_command(command) self.ctx = context self.ns = None + + def __enter__(self): self.load() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_val: + return False # reraise the exception + self.unload() @property def filename(self): @@ -35,7 +43,7 @@ class Script(object): if os.name == "nt": # Windows: escape all backslashes in the path. backslashes = shlex.split(command, posix=False)[0].count("\\") command = command.replace("\\", "\\\\", backslashes) - args = shlex.split(command) + args = shlex.split(command) # pragma: nocover args[0] = os.path.expanduser(args[0]) if not os.path.exists(args[0]): raise ScriptException( @@ -58,7 +66,7 @@ class Script(object): ScriptException on failure """ if self.ns is not None: - self.unload() + raise ScriptException("Script is already loaded") script_dir = os.path.dirname(os.path.abspath(self.args[0])) self.ns = {'__file__': os.path.abspath(self.args[0])} sys.path.append(script_dir) diff --git a/test/mitmproxy/script/__init__.py b/test/mitmproxy/script/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/mitmproxy/script/test_script.py b/test/mitmproxy/script/test_script.py new file mode 100644 index 000000000..a9b559770 --- /dev/null +++ b/test/mitmproxy/script/test_script.py @@ -0,0 +1,83 @@ +from mitmproxy.script import Script +from mitmproxy.exceptions import ScriptException +from test.mitmproxy import tutils + + +class TestParseCommand: + def test_empty_command(self): + with tutils.raises(ScriptException): + Script.parse_command("") + + with tutils.raises(ScriptException): + Script.parse_command(" ") + + def test_no_script_file(self): + with tutils.raises("not found"): + Script.parse_command("notfound") + + with tutils.tmpdir() as dir: + with tutils.raises("not a file"): + Script.parse_command(dir) + + def test_parse_args(self): + with tutils.chdir(tutils.test_data.dirname): + assert Script.parse_command("scripts/a.py") == ["scripts/a.py"] + assert Script.parse_command("scripts/a.py foo bar") == ["scripts/a.py", "foo", "bar"] + assert Script.parse_command("scripts/a.py 'foo bar'") == ["scripts/a.py", "foo bar"] + + @tutils.skip_not_windows + def test_parse_windows(self): + with tutils.chdir(tutils.test_data.dirname): + assert Script.parse_command("scripts\\a.py") == ["scripts\\a.py"] + assert Script.parse_command("scripts\\a.py 'foo \\ bar'") == ["scripts\\a.py", 'foo \\ bar'] + + +def test_simple(): + with tutils.chdir(tutils.test_data.path("scripts")): + s = Script("a.py --var 42", None) + assert s.filename == "a.py" + assert s.ns is None + + s.load() + assert s.ns["var"] == 42 + + s.run("here") + assert s.ns["var"] == 43 + + s.unload() + assert s.ns is None + + with tutils.raises(ScriptException): + s.run("here") + + with Script("a.py --var 42", None) as s: + s.run("here") + + +def test_script_exception(): + with tutils.chdir(tutils.test_data.path("scripts")): + s = Script("syntaxerr.py", None) + with tutils.raises(ScriptException): + s.load() + + s = Script("starterr.py", None) + with tutils.raises(ScriptException): + s.load() + + s = Script("a.py", None) + s.load() + with tutils.raises(ScriptException): + s.load() + + s = Script("a.py", None) + with tutils.raises(ScriptException): + s.run("here") + + with tutils.raises(ScriptException): + with Script("reqerr.py", None) as s: + s.run("request", None) + + s = Script("unloaderr.py", None) + s.load() + with tutils.raises(ScriptException): + s.unload() diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index 803776ac8..b560d9a17 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -31,9 +31,8 @@ class DummyContext(object): def example(command): command = os.path.join(example_dir, command) ctx = DummyContext() - s = script.Script(command, ctx) - yield s - s.unload() + with script.Script(command, ctx) as s: + yield s def test_load_scripts(): @@ -52,8 +51,10 @@ def test_load_scripts(): f += " ~a" if "modify_response_body" in f: f += " foo bar" # two arguments required + + s = script.Script(f, script.ScriptContext(tmaster)) try: - s = script.Script(f, script.ScriptContext(tmaster)) # Loads the script file. + s.load() except Exception as v: if "ImportError" not in str(v): raise diff --git a/test/mitmproxy/test_script.py b/test/mitmproxy/test_script.py index b827c623c..b38ea041a 100644 --- a/test/mitmproxy/test_script.py +++ b/test/mitmproxy/test_script.py @@ -1,27 +1,9 @@ -import os import time import mock from mitmproxy import script, flow from . import tutils -def test_simple(): - s = flow.State() - fm = flow.FlowMaster(None, s) - sp = tutils.test_data.path("scripts/a.py") - p = script.Script("%s --var 40" % sp, script.ScriptContext(fm)) - - assert "here" in p.ns - assert p.run("here") == 41 - assert p.run("here") == 42 - - tutils.raises(script.ScriptException, p.run, "errargs") - - # Check reload - p.load() - assert p.run("here") == 41 - - def test_duplicate_flow(): s = flow.State() fm = flow.FlowMaster(None, s) @@ -33,35 +15,6 @@ def test_duplicate_flow(): assert fm.state.view[1].request.is_replay -def test_err(): - s = flow.State() - fm = flow.FlowMaster(None, s) - sc = script.ScriptContext(fm) - - tutils.raises( - "not found", - script.Script, "nonexistent", sc - ) - - tutils.raises( - "not a file", - script.Script, tutils.test_data.path("scripts"), sc - ) - - tutils.raises( - script.ScriptException, - script.Script, tutils.test_data.path("scripts/syntaxerr.py"), sc - ) - - tutils.raises( - script.ScriptException, - script.Script, tutils.test_data.path("scripts/loaderr.py"), sc - ) - - scr = script.Script(tutils.test_data.path("scripts/unloaderr.py"), sc) - tutils.raises(script.ScriptException, scr.unload) - - @tutils.skip_appveyor def test_concurrent(): s = flow.State() @@ -117,16 +70,6 @@ def test_concurrent2(): def test_concurrent_err(): s = flow.State() fm = flow.FlowMaster(None, s) - tutils.raises( - "Concurrent decorator not supported for 'start' method", - script.Script, - tutils.test_data.path("scripts/concurrent_decorator_err.py"), - fm) - - -def test_command_parsing(): - s = flow.State() - fm = flow.FlowMaster(None, s) - absfilepath = os.path.normcase(tutils.test_data.path("scripts/a.py")) - s = script.Script(absfilepath, script.ScriptContext(fm)) - assert os.path.isfile(s.args[0]) + with tutils.raises("Concurrent decorator not supported for 'start' method"): + s = script.Script(tutils.test_data.path("scripts/concurrent_decorator_err.py"), fm) + s.load() \ No newline at end of file diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py index 0d65df717..791db6d9e 100644 --- a/test/mitmproxy/tutils.py +++ b/test/mitmproxy/tutils.py @@ -26,6 +26,13 @@ def skip_windows(fn): return fn +def skip_not_windows(fn): + if os.name == "nt": + return fn + else: + return _skip_windows + + def _skip_appveyor(*args): raise SkipTest("Skipped on AppVeyor.") @@ -119,15 +126,18 @@ def get_body_line(last_displayed_body, line_nb): return last_displayed_body.contents()[line_nb + 2] +@contextmanager +def chdir(dir): + orig_dir = os.getcwd() + os.chdir(dir) + yield + os.chdir(orig_dir) + @contextmanager def tmpdir(*args, **kwargs): - orig_workdir = os.getcwd() temp_workdir = tempfile.mkdtemp(*args, **kwargs) - os.chdir(temp_workdir) - - yield temp_workdir - - os.chdir(orig_workdir) + with chdir(temp_workdir): + yield temp_workdir shutil.rmtree(temp_workdir)