diff --git a/mitmproxy/command.py b/mitmproxy/command.py index 77494100f..8f0755254 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -204,7 +204,15 @@ class CommandManager(mitmproxy.types._CommandBase): return parse, remhelp - def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any: + def call(self, path: str, *args: typing.Sequence[typing.Any]) -> typing.Any: + """ + Call a command with native arguments. May raise CommandError. + """ + if path not in self.commands: + raise exceptions.CommandError("Unknown command: %s" % path) + return self.commands[path].func(*args) + + def call_strings(self, path: str, args: typing.Sequence[str]) -> typing.Any: """ Call a command using a list of string arguments. May raise CommandError. """ @@ -212,14 +220,14 @@ class CommandManager(mitmproxy.types._CommandBase): raise exceptions.CommandError("Unknown command: %s" % path) return self.commands[path].call(args) - def call(self, cmdstr: str): + def execute(self, cmdstr: str): """ - Call a command using a string. May raise CommandError. + Execute a command string. May raise CommandError. """ parts = list(lexer(cmdstr)) if not len(parts) >= 1: raise exceptions.CommandError("Invalid command: %s" % cmdstr) - return self.call_args(parts[0], parts[1:]) + return self.call_strings(parts[0], parts[1:]) def dump(self, out=sys.stdout) -> None: cmds = list(self.commands.values()) diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 2b9ff334c..81f568ed0 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -258,7 +258,7 @@ class ConsoleAddon: command, then invoke another command with all occurrences of {choice} replaced by the choice the user made. """ - choices = ctx.master.commands.call_args(choicecmd, []) + choices = ctx.master.commands.call_strings(choicecmd, []) def callback(opt): # We're now outside of the call context... @@ -514,7 +514,7 @@ class ConsoleAddon: raise exceptions.CommandError("Invalid flowview mode.") try: - self.master.commands.call_args( + self.master.commands.call_strings( "view.setval", ["@focus", "flowview_mode_%s" % idx, mode] ) @@ -537,7 +537,7 @@ class ConsoleAddon: if not fv: raise exceptions.CommandError("Not viewing a flow.") idx = fv.body.tab_offset - return self.master.commands.call_args( + return self.master.commands.call_strings( "view.getval", [ "@focus", diff --git a/mitmproxy/tools/web/app.py b/mitmproxy/tools/web/app.py index ea431d2f7..ae2394eb0 100644 --- a/mitmproxy/tools/web/app.py +++ b/mitmproxy/tools/web/app.py @@ -344,8 +344,7 @@ class ReplayFlow(RequestHandler): self.view.update([self.flow]) try: - self.master.command.call - self.master.replay_request(self.flow) + self.master.commands.call("replay.client", [self.flow]) except exceptions.ReplayException as e: raise APIError(400, str(e)) diff --git a/mitmproxy/types.py b/mitmproxy/types.py index 23320c121..283e7e2ec 100644 --- a/mitmproxy/types.py +++ b/mitmproxy/types.py @@ -47,10 +47,10 @@ class Choice: class _CommandBase: commands: typing.MutableMapping[str, typing.Any] = {} - def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any: + def call_strings(self, path: str, args: typing.Sequence[str]) -> typing.Any: raise NotImplementedError - def call(self, cmd: str) -> typing.Any: + def execute(self, cmd: str) -> typing.Any: raise NotImplementedError @@ -337,7 +337,7 @@ class _FlowType(_BaseFlowType): def parse(self, manager: _CommandBase, t: type, s: str) -> flow.Flow: try: - flows = manager.call_args("view.resolve", [s]) + flows = manager.call_strings("view.resolve", [s]) except exceptions.CommandError as e: raise exceptions.TypeError from e if len(flows) != 1: @@ -356,7 +356,7 @@ class _FlowsType(_BaseFlowType): def parse(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[flow.Flow]: try: - return manager.call_args("view.resolve", [s]) + return manager.call_strings("view.resolve", [s]) except exceptions.CommandError as e: raise exceptions.TypeError from e @@ -401,17 +401,17 @@ class _ChoiceType(_BaseType): display = "choice" def completion(self, manager: _CommandBase, t: Choice, s: str) -> typing.Sequence[str]: - return manager.call(t.options_command) + return manager.execute(t.options_command) def parse(self, manager: _CommandBase, t: Choice, s: str) -> str: - opts = manager.call(t.options_command) + opts = manager.execute(t.options_command) if s not in opts: raise exceptions.TypeError("Invalid choice.") return s def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: try: - opts = manager.call(typ.options_command) + opts = manager.execute(typ.options_command) except exceptions.CommandError: return False return val in opts diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index 616caf585..4aa1f6488 100644 --- a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -73,7 +73,7 @@ def test_save_command(tmpdir): v = view.View() tctx.master.addons.add(v) tctx.master.addons.add(sa) - tctx.master.commands.call_args("save.file", ["@shown", p]) + tctx.master.commands.call_strings("save.file", ["@shown", p]) def test_simple(tmpdir): diff --git a/test/mitmproxy/test_addonmanager.py b/test/mitmproxy/test_addonmanager.py index 796ae1bdd..1ef1521da 100644 --- a/test/mitmproxy/test_addonmanager.py +++ b/test/mitmproxy/test_addonmanager.py @@ -47,7 +47,7 @@ class AOption: def test_command(): with taddons.context() as tctx: tctx.master.addons.add(TAddon("test")) - assert tctx.master.commands.call("test.command") == "here" + assert tctx.master.commands.execute("test.command") == "here" def test_halt(): diff --git a/test/mitmproxy/test_command.py b/test/mitmproxy/test_command.py index 3d0a43f88..ea1017e7c 100644 --- a/test/mitmproxy/test_command.py +++ b/test/mitmproxy/test_command.py @@ -242,16 +242,19 @@ def test_simple(): a = TAddon() c.add("one.two", a.cmd1) assert c.commands["one.two"].help == "cmd1 help" - assert(c.call("one.two foo") == "ret foo") + assert(c.execute("one.two foo") == "ret foo") + assert(c.call("one.two", "foo") == "ret foo") + with pytest.raises(exceptions.CommandError, match="Unknown"): + c.execute("nonexistent") + with pytest.raises(exceptions.CommandError, match="Invalid"): + c.execute("") + with pytest.raises(exceptions.CommandError, match="argument mismatch"): + c.execute("one.two too many args") with pytest.raises(exceptions.CommandError, match="Unknown"): c.call("nonexistent") - with pytest.raises(exceptions.CommandError, match="Invalid"): - c.call("") - with pytest.raises(exceptions.CommandError, match="argument mismatch"): - c.call("one.two too many args") c.add("empty", a.empty) - c.call("empty") + c.execute("empty") fp = io.StringIO() c.dump(fp) @@ -340,13 +343,13 @@ def test_decorator(): a = TDec() c.collect_commands(a) assert "cmd1" in c.commands - assert c.call("cmd1 bar") == "ret bar" + assert c.execute("cmd1 bar") == "ret bar" assert "empty" in c.commands - assert c.call("empty") is None + assert c.execute("empty") is None with taddons.context() as tctx: tctx.master.addons.add(a) - assert tctx.master.commands.call("cmd1 bar") == "ret bar" + assert tctx.master.commands.execute("cmd1 bar") == "ret bar" def test_verify_arg_signature(): diff --git a/test/mitmproxy/tools/web/test_app.py b/test/mitmproxy/tools/web/test_app.py index 3d18987de..668d3c07d 100644 --- a/test/mitmproxy/tools/web/test_app.py +++ b/test/mitmproxy/tools/web/test_app.py @@ -184,15 +184,10 @@ class TestApp(tornado.testing.AsyncHTTPTestCase): self.fetch("/flows/42/revert", method="POST") assert not f._backup - # FIXME - # def test_flow_replay(self): - # with mock.patch("mitmproxy.master.Master.replay_request") as replay_request: - # assert self.fetch("/flows/42/replay", method="POST").code == 200 - # assert replay_request.called - # replay_request.side_effect = exceptions.ReplayException( - # "out of replays" - # ) - # assert self.fetch("/flows/42/replay", method="POST").code == 400 + def test_flow_replay(self): + with mock.patch("mitmproxy.command.CommandManager.call") as replay_call: + assert self.fetch("/flows/42/replay", method="POST").code == 200 + assert replay_call.called def test_flow_content(self): f = self.view.get_by_id("42")