diff --git a/mitmproxy/command.py b/mitmproxy/command.py index 8f0755254..c94e8abb2 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -35,9 +35,11 @@ def typename(t: type) -> str: """ Translates a type to an explanatory string. """ + if t == inspect._empty: # type: ignore + raise exceptions.CommandError("missing type annotation") to = mitmproxy.types.CommandTypes.get(t, None) if not to: - raise NotImplementedError(t) + raise exceptions.CommandError("unsupported type: %s" % getattr(t, "__name__", t)) return to.display @@ -58,7 +60,12 @@ class Command: if i.kind == i.VAR_POSITIONAL: self.has_positional = True self.paramtypes = [v.annotation for v in sig.parameters.values()] - self.returntype = sig.return_annotation + if sig.return_annotation == inspect._empty: # type: ignore + self.returntype = None + else: + self.returntype = sig.return_annotation + # This fails with a CommandException if types are invalid + self.signature_help() def paramnames(self) -> typing.Sequence[str]: v = [typename(i) for i in self.paramtypes] @@ -133,7 +140,12 @@ class CommandManager(mitmproxy.types._CommandBase): pass # hasattr may raise if o implements __getattr__. else: if is_command: - self.add(o.command_path, o) + try: + self.add(o.command_path, o) + except exceptions.CommandError as e: + self.master.log.warn( + "Could not load command %s: %s" % (o.command_path, e) + ) def add(self, path: str, func: typing.Callable): self.commands[path] = Command(self, path, func) diff --git a/mitmproxy/tools/console/commands.py b/mitmproxy/tools/console/commands.py index 1183ee9d4..0f35742b3 100644 --- a/mitmproxy/tools/console/commands.py +++ b/mitmproxy/tools/console/commands.py @@ -46,12 +46,13 @@ class CommandItem(urwid.WidgetWrap): class CommandListWalker(urwid.ListWalker): def __init__(self, master): self.master = master - self.index = 0 - self.focusobj = None - self.cmds = list(master.commands.commands.values()) + self.refresh() + + def refresh(self): + self.cmds = list(self.master.commands.commands.values()) self.cmds.sort(key=lambda x: x.signature_help()) - self.set_focus(0) + self.set_focus(self.index) def get_edit_text(self): return self.focus_obj.get_edit_text() @@ -137,6 +138,9 @@ class Commands(urwid.Pile, layoutwidget.LayoutWidget): ) self.master = master + def layout_pushed(self, prev): + self.widget_list[0].walker.refresh() + def keypress(self, size, key): if key == "m_next": self.focus_position = ( diff --git a/test/mitmproxy/test_command.py b/test/mitmproxy/test_command.py index ea1017e7c..7c0dc06d9 100644 --- a/test/mitmproxy/test_command.py +++ b/test/mitmproxy/test_command.py @@ -1,4 +1,5 @@ import typing +import inspect from mitmproxy import command from mitmproxy import flow from mitmproxy import exceptions @@ -55,7 +56,35 @@ class TAddon: pass +class Unsupported: + pass + + +class TypeErrAddon: + @command.command("noret") + def noret(self): + pass + + @command.command("invalidret") + def invalidret(self) -> Unsupported: + pass + + @command.command("invalidarg") + def invalidarg(self, u: Unsupported): + pass + + class TestCommand: + def test_typecheck(self): + with taddons.context(loadcore=False) as tctx: + cm = command.CommandManager(tctx.master) + a = TypeErrAddon() + command.Command(cm, "noret", a.noret) + with pytest.raises(exceptions.CommandError): + command.Command(cm, "invalidret", a.invalidret) + with pytest.raises(exceptions.CommandError): + command.Command(cm, "invalidarg", a.invalidarg) + def test_varargs(self): with taddons.context() as tctx: cm = command.CommandManager(tctx.master) @@ -275,6 +304,11 @@ def test_typename(): assert command.typename(mitmproxy.types.Path) == "path" assert command.typename(mitmproxy.types.Cmd) == "cmd" + with pytest.raises(exceptions.CommandError, match="missing type annotation"): + command.typename(inspect._empty) + with pytest.raises(exceptions.CommandError, match="unsupported type"): + command.typename(None) + class DummyConsole: @command.command("view.resolve") @@ -326,7 +360,8 @@ class TCmds(TAttr): pass -def test_collect_commands(): +@pytest.mark.asyncio +async def test_collect_commands(): """ This tests for the error thrown by hasattr() """ @@ -336,6 +371,10 @@ def test_collect_commands(): c.collect_commands(a) assert "empty" in c.commands + a = TypeErrAddon() + c.collect_commands(a) + await tctx.master.await_log("Could not load") + def test_decorator(): with taddons.context() as tctx: