diff --git a/mitmproxy/addons/core.py b/mitmproxy/addons/core.py index 8a63422d3..69df006f4 100644 --- a/mitmproxy/addons/core.py +++ b/mitmproxy/addons/core.py @@ -221,7 +221,7 @@ class Core: return ["gzip", "deflate", "br"] @command.command("options.load") - def options_load(self, path: str) -> None: + def options_load(self, path: command.Path) -> None: """ Load options from a file. """ @@ -233,7 +233,7 @@ class Core: ) from e @command.command("options.save") - def options_save(self, path: str) -> None: + def options_save(self, path: command.Path) -> None: """ Save options to a file. """ diff --git a/mitmproxy/addons/cut.py b/mitmproxy/addons/cut.py index a4a2107b7..5ec4c99eb 100644 --- a/mitmproxy/addons/cut.py +++ b/mitmproxy/addons/cut.py @@ -96,7 +96,7 @@ class Cut: return ret @command.command("cut.save") - def save(self, cuts: command.Cuts, path: str) -> None: + def save(self, cuts: command.Cuts, path: command.Path) -> None: """ Save cuts to file. If there are multiple rows or columns, the format is UTF-8 encoded CSV. If there is exactly one row and one column, @@ -107,7 +107,7 @@ class Cut: append = False if path.startswith("+"): append = True - path = path[1:] + path = command.Path(path[1:]) if len(cuts) == 1 and len(cuts[0]) == 1: with open(path, "ab" if append else "wb") as fp: if fp.tell() > 0: diff --git a/mitmproxy/addons/export.py b/mitmproxy/addons/export.py index fd0c830e5..5388a0e88 100644 --- a/mitmproxy/addons/export.py +++ b/mitmproxy/addons/export.py @@ -49,7 +49,7 @@ class Export(): return list(sorted(formats.keys())) @command.command("export.file") - def file(self, fmt: str, f: flow.Flow, path: str) -> None: + def file(self, fmt: str, f: flow.Flow, path: command.Path) -> None: """ Export a flow to path. """ diff --git a/mitmproxy/addons/save.py b/mitmproxy/addons/save.py index 5e739039e..40cd6f827 100644 --- a/mitmproxy/addons/save.py +++ b/mitmproxy/addons/save.py @@ -1,6 +1,7 @@ import os.path import typing +from mitmproxy import command from mitmproxy import exceptions from mitmproxy import flowfilter from mitmproxy import io @@ -48,7 +49,8 @@ class Save: if ctx.options.save_stream_file: self.start_stream_to_path(ctx.options.save_stream_file, self.filt) - def save(self, flows: typing.Sequence[flow.Flow], path: str) -> None: + @command.command("save.file") + def save(self, flows: typing.Sequence[flow.Flow], path: command.Path) -> None: """ Save flows to a file. If the path starts with a +, flows are appended to the file, otherwise it is over-written. @@ -63,9 +65,6 @@ class Save: f.close() ctx.log.alert("Saved %s flows." % len(flows)) - def load(self, l): - l.add_command("save.file", self.save) - def tcp_start(self, flow): if self.stream: self.active_flows.add(flow) diff --git a/mitmproxy/addons/view.py b/mitmproxy/addons/view.py index 6f0fd131b..8381e0250 100644 --- a/mitmproxy/addons/view.py +++ b/mitmproxy/addons/view.py @@ -351,13 +351,13 @@ class View(collections.Sequence): ctx.master.addons.trigger("update", updated) @command.command("view.load") - def load_file(self, path: str) -> None: + def load_file(self, path: command.Path) -> None: """ Load flows into the view, without processing them with addons. """ - path = os.path.expanduser(path) + spath = os.path.expanduser(path) try: - with open(path, "rb") as f: + with open(spath, "rb") as f: for i in io.FlowReader(f).stream(): # Do this to get a new ID, so we can load the same file N times and # get new flows each time. It would be more efficient to just have a diff --git a/mitmproxy/command.py b/mitmproxy/command.py index 25e00174c..a14c95d0b 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -37,6 +37,9 @@ class Choice(str): options_command = "" +Path = typing.NewType("Path", str) + + def typename(t: type, ret: bool) -> str: """ Translates a type to an explanatory string. If ret is True, we're @@ -44,8 +47,6 @@ def typename(t: type, ret: bool) -> str: """ if hasattr(t, "options_command"): return "choice" - elif issubclass(t, (str, int, bool)): - return t.__name__ elif t == typing.Sequence[flow.Flow]: return "[flow]" if ret else "flowspec" elif t == typing.Sequence[str]: @@ -54,6 +55,10 @@ def typename(t: type, ret: bool) -> str: return "[cuts]" if ret else "cutspec" elif t == flow.Flow: return "flow" + elif t == Path: + return "path" + elif issubclass(t, (str, int, bool)): + return t.__name__ else: # pragma: no cover raise NotImplementedError(t) @@ -186,7 +191,9 @@ def parsearg(manager: CommandManager, spec: str, argtype: type) -> typing.Any: "Invalid choice: see %s for options" % cmd ) return spec - if issubclass(argtype, str): + if argtype == Path: + return spec + elif issubclass(argtype, str): return spec elif argtype == bool: if spec == "true": diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 4b0bb00d8..a65b26e12 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -414,14 +414,14 @@ class ConsoleAddon: self._grideditor().cmd_delete() @command.command("console.grideditor.readfile") - def grideditor_readfile(self, path: str) -> None: + def grideditor_readfile(self, path: command.Path) -> None: """ Read a file into the currrent cell. """ self._grideditor().cmd_read_file(path) @command.command("console.grideditor.readfile_escaped") - def grideditor_readfile_escaped(self, path: str) -> None: + def grideditor_readfile_escaped(self, path: command.Path) -> None: """ Read a file containing a Python-style escaped stringinto the currrent cell. diff --git a/test/mitmproxy/test_command.py b/test/mitmproxy/test_command.py index cb9dc4ede..abb09ceaa 100644 --- a/test/mitmproxy/test_command.py +++ b/test/mitmproxy/test_command.py @@ -35,6 +35,9 @@ class TAddon: def choose(self, arg: TChoice) -> typing.Sequence[str]: # type: ignore return ["one", "two", "three"] + def path(self, arg: command.Path) -> None: + pass + class TestCommand: def test_varargs(self): @@ -97,6 +100,7 @@ def test_typename(): assert command.typename(typing.Sequence[str], False) == "[str]" assert command.typename(TChoice, False) == "choice" + assert command.typename(command.Path, False) == "path" class DummyConsole: @@ -156,6 +160,10 @@ def test_parsearg(): tctx.master.commands, "invalid", TChoice, ) + assert command.parsearg( + tctx.master.commands, "foo", command.Path + ) == "foo" + class TDec: @command.command("cmd1")