diff --git a/docs/scripts/pdoc-template/module.html.jinja2 b/docs/scripts/pdoc-template/module.html.jinja2 index dbe3baf15..268c393fb 100644 --- a/docs/scripts/pdoc-template/module.html.jinja2 +++ b/docs/scripts/pdoc-template/module.html.jinja2 @@ -55,6 +55,10 @@ To document all event hooks, we do a bit of hackery: {% if doc.qualname.startswith("ServerConnectionHookData") and doc.name != "__init__" %} {{ default_is_public(doc) }} {% endif %} + {% elif doc.modulename == "mitmproxy.websocket" %} + {% if doc.qualname != "WebSocketMessage.type" %} + {{ default_is_public(doc) }} + {% endif %} {% else %} {{ default_is_public(doc) }} {% endif %} diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index f7a60918a..3d9f258fe 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -8,7 +8,7 @@ from mitmproxy.proxy import commands, events, server_hooks from mitmproxy.proxy import server from mitmproxy.proxy.layers.tcp import TcpMessageInjected from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected -from mitmproxy.utils import asyncio_utils, human, strutils +from mitmproxy.utils import asyncio_utils, human from wsproto.frame_protocol import Opcode @@ -190,15 +190,14 @@ class Proxyserver: self._connections[event.flow.client_conn.peername].server_event(event) @command.command("inject.websocket") - def inject_websocket(self, flow: Flow, to_client: bool, message: str, is_text: bool = True): + def inject_websocket(self, flow: Flow, to_client: bool, message: bytes, is_text: bool = True): if not isinstance(flow, http.HTTPFlow) or not flow.websocket: ctx.log.warn("Cannot inject WebSocket messages into non-WebSocket flows.") - message_bytes = strutils.escaped_str_to_bytes(message) msg = websocket.WebSocketMessage( Opcode.TEXT if is_text else Opcode.BINARY, not to_client, - message_bytes + message ) event = WebSocketMessageInjected(flow, msg) try: @@ -207,12 +206,11 @@ class Proxyserver: ctx.log.warn(str(e)) @command.command("inject.tcp") - def inject_tcp(self, flow: Flow, to_client: bool, message: str): + def inject_tcp(self, flow: Flow, to_client: bool, message: bytes): if not isinstance(flow, tcp.TCPFlow): ctx.log.warn("Cannot inject TCP messages into non-TCP flows.") - message_bytes = strutils.escaped_str_to_bytes(message) - event = TcpMessageInjected(flow, tcp.TCPMessage(not to_client, message_bytes)) + event = TcpMessageInjected(flow, tcp.TCPMessage(not to_client, message)) try: self.inject_event(event) except ValueError as e: diff --git a/mitmproxy/command.py b/mitmproxy/command.py index 23aed219d..d67e9a424 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -73,7 +73,7 @@ class Command: for name, parameter in self.signature.parameters.items(): t = parameter.annotation if not mitmproxy.types.CommandTypes.get(parameter.annotation, None): - raise exceptions.CommandError(f"Argument {name} has an unknown type ({_empty_as_none(t)}) in {func}.") + raise exceptions.CommandError(f"Argument {name} has an unknown type {t} in {func}.") if self.return_type and not mitmproxy.types.CommandTypes.get(self.return_type, None): raise exceptions.CommandError(f"Return type has an unknown type ({self.return_type}) in {func}.") @@ -106,8 +106,15 @@ class Command: raise exceptions.CommandError(f"Command argument mismatch: \n {expected}\n {received}") for name, value in bound_arguments.arguments.items(): - convert_to = self.signature.parameters[name].annotation - bound_arguments.arguments[name] = parsearg(self.manager, value, convert_to) + param = self.signature.parameters[name] + convert_to = param.annotation + if param.kind == param.VAR_POSITIONAL: + bound_arguments.arguments[name] = tuple( + parsearg(self.manager, x, convert_to) + for x in value + ) + else: + bound_arguments.arguments[name] = parsearg(self.manager, value, convert_to) bound_arguments.apply_defaults() diff --git a/mitmproxy/command_lexer.py b/mitmproxy/command_lexer.py index 31458f4d2..2ba691dff 100644 --- a/mitmproxy/command_lexer.py +++ b/mitmproxy/command_lexer.py @@ -1,4 +1,3 @@ -import ast import re import pyparsing @@ -10,13 +9,9 @@ import pyparsing PartialQuotedString = pyparsing.Regex( re.compile( r''' - (["']) # start quote - (?: - (?:\\.) # escape sequence - | - (?!\1). # unescaped character that is not our quote nor the begin of an escape sequence. We can't use \1 in [] - )* - (?:\1|$) # end quote + "[^"]*(?:"|$) # double-quoted string that ends with double quote or EOF + | + '[^']*(?:'|$) # single-quoted string that ends with double quote or EOF ''', re.VERBOSE ) @@ -32,18 +27,15 @@ expr = pyparsing.ZeroOrMore( def quote(val: str) -> str: if val and all(char not in val for char in "'\" \r\n\t"): return val - return repr(val) # TODO: More of a hack. + if '"' not in val: + return f'"{val}"' + if "'" not in val: + return f"'{val}'" + return '"' + val.replace('"', r"\x22") + '"' def unquote(x: str) -> str: - quoted = ( - (x.startswith('"') and x.endswith('"')) - or - (x.startswith("'") and x.endswith("'")) - ) - if quoted: - try: - x = ast.literal_eval(x) - except Exception: - x = x[1:-1] - return x + if len(x) > 1 and x[0] in "'\"" and x[0] == x[-1]: + return x[1:-1] + else: + return x diff --git a/mitmproxy/tools/console/grideditor/col_bytes.py b/mitmproxy/tools/console/grideditor/col_bytes.py index 990253ea4..e27a4474a 100644 --- a/mitmproxy/tools/console/grideditor/col_bytes.py +++ b/mitmproxy/tools/console/grideditor/col_bytes.py @@ -34,8 +34,8 @@ class Display(base.Cell): class Edit(base.Cell): def __init__(self, data: bytes) -> None: - data = strutils.bytes_to_escaped_str(data) - w = urwid.Edit(edit_text=data, wrap="any", multiline=True) + d = strutils.bytes_to_escaped_str(data) + w = urwid.Edit(edit_text=d, wrap="any", multiline=True) w = urwid.AttrWrap(w, "editfield") super().__init__(w) diff --git a/mitmproxy/types.py b/mitmproxy/types.py index e02a1f670..01d7ac47f 100644 --- a/mitmproxy/types.py +++ b/mitmproxy/types.py @@ -1,10 +1,12 @@ +import codecs import os import glob +import re import typing from mitmproxy import exceptions from mitmproxy import flow -from mitmproxy.utils import emoji +from mitmproxy.utils import emoji, strutils if typing.TYPE_CHECKING: # pragma: no cover from mitmproxy.command import CommandManager @@ -104,16 +106,52 @@ class _StrType(_BaseType): typ = str display = "str" + # https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals + escape_sequences = re.compile(r""" + \\ ( + [\\'"abfnrtv] # Standard C escape sequence + | [0-7]{1,3} # Character with octal value + | x.. # Character with hex value + | N{[^}]+} # Character name in the Unicode database + | u.... # Character with 16-bit hex value + | U........ # Character with 32-bit hex value + ) + """, re.VERBOSE) + + @staticmethod + def _unescape(match: re.Match) -> str: + return codecs.decode(match.group(0), "unicode-escape") # type: ignore + def completion(self, manager: "CommandManager", t: type, s: str) -> typing.Sequence[str]: return [] def parse(self, manager: "CommandManager", t: type, s: str) -> str: - return s + try: + return self.escape_sequences.sub(self._unescape, s) + except ValueError as e: + raise exceptions.TypeError(f"Invalid str: {e}") from e def is_valid(self, manager: "CommandManager", typ: typing.Any, val: typing.Any) -> bool: return isinstance(val, str) +class _BytesType(_BaseType): + typ = bytes + display = "bytes" + + def completion(self, manager: "CommandManager", t: type, s: str) -> typing.Sequence[str]: + return [] + + def parse(self, manager: "CommandManager", t: type, s: str) -> bytes: + try: + return strutils.escaped_str_to_bytes(s) + except ValueError as e: + raise exceptions.TypeError(str(e)) + + def is_valid(self, manager: "CommandManager", typ: typing.Any, val: typing.Any) -> bool: + return isinstance(val, bytes) + + class _UnknownType(_BaseType): typ = Unknown display = "unknown" @@ -460,4 +498,5 @@ CommandTypes = TypeManager( _PathType, _StrType, _StrSeqType, + _BytesType, ) diff --git a/mitmproxy/utils/strutils.py b/mitmproxy/utils/strutils.py index 3debb2aa3..0622b737d 100644 --- a/mitmproxy/utils/strutils.py +++ b/mitmproxy/utils/strutils.py @@ -79,7 +79,7 @@ def escape_control_characters(text: str, keep_spacing=True) -> str: return text.translate(trans) -def bytes_to_escaped_str(data, keep_spacing=False, escape_single_quotes=False): +def bytes_to_escaped_str(data: bytes, keep_spacing: bool = False, escape_single_quotes: bool = False) -> str: """ Take bytes and return a safe string that can be displayed to the user. @@ -107,7 +107,7 @@ def bytes_to_escaped_str(data, keep_spacing=False, escape_single_quotes=False): return ret -def escaped_str_to_bytes(data): +def escaped_str_to_bytes(data: str) -> bytes: """ Take an escaped string and return the unescaped bytes equivalent. @@ -119,7 +119,7 @@ def escaped_str_to_bytes(data): # This one is difficult - we use an undocumented Python API here # as per http://stackoverflow.com/a/23151714/934719 - return codecs.escape_decode(data)[0] + return codecs.escape_decode(data)[0] # type: ignore def is_mostly_bin(s: bytes) -> bool: diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 7fd0fcb02..00accbf55 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -20,18 +20,16 @@ class WebSocketMessage(serializable.Serializable): """ A single WebSocket message sent from one peer to the other. - Fragmented WebSocket messages are reassembled by mitmproxy and the + Fragmented WebSocket messages are reassembled by mitmproxy and then represented as a single instance of this class. The [WebSocket RFC](https://tools.ietf.org/html/rfc6455) specifies both text and binary messages. To avoid a whole class of nasty type confusion bugs, - mitmproxy stores all message contents as binary. If you need text, you can decode the `content` property: + mitmproxy stores all message contents as `bytes`. If you need a `str`, you can access the `text` property + on text messages: - >>> from wsproto.frame_protocol import Opcode - >>> if message.type == Opcode.TEXT: - >>> text = message.content.decode() - - Per the WebSocket spec, text messages always use UTF-8 encoding. + >>> if message.is_text: + >>> text = message.text """ from_client: bool @@ -40,8 +38,7 @@ class WebSocketMessage(serializable.Serializable): """ The message type, as per RFC 6455's [opcode](https://tools.ietf.org/html/rfc6455#section-5.2). - Note that mitmproxy will always store the message contents as *bytes*. - A dedicated `.text` property for text messages is planned, see https://github.com/mitmproxy/mitmproxy/pull/4486. + Mitmproxy currently only exposes messages assembled from `TEXT` and `BINARY` frames. """ content: bytes """A byte-string representing the content of this message.""" @@ -81,10 +78,39 @@ class WebSocketMessage(serializable.Serializable): else: return repr(self.content) + @property + def is_text(self) -> bool: + """ + `True` if this message is assembled from WebSocket `TEXT` frames, + `False` if it is assembled from `BINARY` frames. + """ + return self.type == Opcode.TEXT + def kill(self): # Likely to be replaced with .drop() in the future, see https://github.com/mitmproxy/mitmproxy/pull/4486 self.killed = True + @property + def text(self) -> str: + """ + The message content as text. + + This attribute is only available if `WebSocketMessage.is_text` is `True`. + + *See also:* `WebSocketMessage.content` + """ + if self.type != Opcode.TEXT: + raise AttributeError(f"{self.type.name.title()} WebSocket frames do not have a 'text' attribute.") + + return self.content.decode() + + @text.setter + def text(self, value: str) -> None: + if self.type != Opcode.TEXT: + raise AttributeError(f"{self.type.name.title()} WebSocket frames do not have a 'text' attribute.") + + self.content = value.encode() + class WebSocketData(stateobject.StateObject): """ @@ -97,9 +123,9 @@ class WebSocketData(stateobject.StateObject): closed_by_client: Optional[bool] = None """ - True if the client closed the connection, - False if the server closed the connection, - None if the connection is active. + `True` if the client closed the connection, + `False` if the server closed the connection, + `None` if the connection is active. """ close_code: Optional[int] = None """[Close Code](https://tools.ietf.org/html/rfc6455#section-7.1.5)""" diff --git a/test/mitmproxy/addons/test_proxyserver.py b/test/mitmproxy/addons/test_proxyserver.py index e2d8d3c4b..be644a963 100644 --- a/test/mitmproxy/addons/test_proxyserver.py +++ b/test/mitmproxy/addons/test_proxyserver.py @@ -90,7 +90,7 @@ async def test_start_stop(): @pytest.mark.asyncio -async def test_inject(): +async def test_inject() -> None: async def server_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): while s := await reader.read(1): writer.write(s.upper()) @@ -112,39 +112,39 @@ async def test_inject(): writer.write(b"a") assert await reader.read(1) == b"A" - ps.inject_tcp(state.flows[0], False, "b") + ps.inject_tcp(state.flows[0], False, b"b") assert await reader.read(1) == b"B" - ps.inject_tcp(state.flows[0], True, "c") + ps.inject_tcp(state.flows[0], True, b"c") assert await reader.read(1) == b"c" @pytest.mark.asyncio -async def test_inject_fail(): +async def test_inject_fail() -> None: ps = Proxyserver() with taddons.context(ps) as tctx: ps.inject_websocket( tflow.tflow(), True, - "test" + b"test" ) await tctx.master.await_log("Cannot inject WebSocket messages into non-WebSocket flows.", level="warn") ps.inject_tcp( tflow.tflow(), True, - "test" + b"test" ) await tctx.master.await_log("Cannot inject TCP messages into non-TCP flows.", level="warn") ps.inject_websocket( tflow.twebsocketflow(), True, - "test" + b"test" ) await tctx.master.await_log("Flow is not from a live connection.", level="warn") ps.inject_websocket( tflow.ttcpflow(), True, - "test" + b"test" ) await tctx.master.await_log("Flow is not from a live connection.", level="warn") diff --git a/test/mitmproxy/test_command.py b/test/mitmproxy/test_command.py index 808cc18b0..a74503841 100644 --- a/test/mitmproxy/test_command.py +++ b/test/mitmproxy/test_command.py @@ -367,24 +367,6 @@ class TestCommand: ], [], ], - [ - r'cmd13 "a \"b\" c"', - [ - command.ParseResult(value="cmd13", type=mitmproxy.types.Cmd, valid=False), - command.ParseResult(value=" ", type=mitmproxy.types.Space, valid=True), - command.ParseResult(value=r'"a \"b\" c"', type=mitmproxy.types.Unknown, valid=False), - ], - [], - ], - [ - r"cmd14 'a \'b\' c'", - [ - command.ParseResult(value="cmd14", type=mitmproxy.types.Cmd, valid=False), - command.ParseResult(value=" ", type=mitmproxy.types.Space, valid=True), - command.ParseResult(value=r"'a \'b\' c'", type=mitmproxy.types.Unknown, valid=False), - ], - [], - ], [ " spaces_at_the_begining_are_not_stripped", [ @@ -436,12 +418,6 @@ def test_simple(): c.call("nonexistent") with pytest.raises(exceptions.CommandError, match="Unknown"): c.execute("\\") - with pytest.raises(exceptions.CommandError, match="Unknown"): - c.execute(r"\'") - with pytest.raises(exceptions.CommandError, match="Unknown"): - c.execute(r"\"") - with pytest.raises(exceptions.CommandError, match="Unknown"): - c.execute(r"\"") c.add("empty", a.empty) c.execute("empty") diff --git a/test/mitmproxy/test_command_lexer.py b/test/mitmproxy/test_command_lexer.py index ec9940874..f94255cdb 100644 --- a/test/mitmproxy/test_command_lexer.py +++ b/test/mitmproxy/test_command_lexer.py @@ -11,7 +11,6 @@ from mitmproxy import command_lexer ("'foo'", True), ('"foo"', True), ("'foo' bar'", False), - ("'foo\\' bar'", True), ("'foo' 'bar'", False), ("'foo'x", False), ('''"foo ''', True), @@ -43,8 +42,19 @@ def test_expr(test_input, expected): @given(text()) +@example(r"foo") +@example(r"'foo\''") +@example(r"'foo\"'") +@example(r'"foo\""') +@example(r'"foo\'"') +@example("'foo\\'") +@example("'foo\\\\'") +@example("\"foo\\'\"") +@example("\"foo\\\\'\"") +@example('\'foo\\"\'') +@example(r"\\\foo") def test_quote_unquote_cycle(s): - assert command_lexer.unquote(command_lexer.quote(s)) == s + assert command_lexer.unquote(command_lexer.quote(s)).replace(r"\x22", '"') == s @given(text()) diff --git a/test/mitmproxy/test_types.py b/test/mitmproxy/test_types.py index 1df108f17..e0f56995f 100644 --- a/test/mitmproxy/test_types.py +++ b/test/mitmproxy/test_types.py @@ -40,6 +40,21 @@ def test_str(): assert b.is_valid(tctx.master.commands, str, 1) is False assert b.completion(tctx.master.commands, str, "") == [] assert b.parse(tctx.master.commands, str, "foo") == "foo" + assert b.parse(tctx.master.commands, str, r"foo\nbar") == "foo\nbar" + assert b.parse(tctx.master.commands, str, r"\N{BELL}") == "🔔" + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, bool, r"\N{UNKNOWN UNICODE SYMBOL!}") + + +def test_bytes(): + with taddons.context() as tctx: + b = mitmproxy.types._BytesType() + assert b.is_valid(tctx.master.commands, bytes, b"foo") is True + assert b.is_valid(tctx.master.commands, bytes, 1) is False + assert b.completion(tctx.master.commands, bytes, "") == [] + assert b.parse(tctx.master.commands, bytes, "foo") == b"foo" + with pytest.raises(mitmproxy.exceptions.TypeError): + b.parse(tctx.master.commands, bytes, "incomplete escape sequence\\") def test_unknown(): diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py index 26d7b0f6a..d7404f290 100644 --- a/test/mitmproxy/test_websocket.py +++ b/test/mitmproxy/test_websocket.py @@ -1,3 +1,5 @@ +import pytest + from mitmproxy import http from mitmproxy import websocket from mitmproxy.test import tflow @@ -26,3 +28,18 @@ class TestWebSocketMessage: assert not m.killed m.kill() assert m.killed + + def test_text(self): + txt = websocket.WebSocketMessage(Opcode.TEXT, True, b"foo") + bin = websocket.WebSocketMessage(Opcode.BINARY, True, b"foo") + + assert txt.is_text + assert txt.text == "foo" + txt.text = "bar" + assert txt.content == b"bar" + + assert not bin.is_text + with pytest.raises(AttributeError, match="do not have a 'text' attribute."): + _ = bin.text + with pytest.raises(AttributeError, match="do not have a 'text' attribute."): + bin.text = "bar" diff --git a/test/mitmproxy/tools/console/test_defaultkeys.py b/test/mitmproxy/tools/console/test_defaultkeys.py index f87c20224..389a8eefe 100644 --- a/test/mitmproxy/tools/console/test_defaultkeys.py +++ b/test/mitmproxy/tools/console/test_defaultkeys.py @@ -20,18 +20,18 @@ async def test_commands_exist(): await m.load_flow(tflow()) for binding in km.bindings: - parsed, _ = command_manager.parse_partial(binding.command.strip()) - - cmd = parsed[0].value - args = [ - a.value for a in parsed[1:] - if a.type != mitmproxy.types.Space - ] - - assert cmd in m.commands.commands - - cmd_obj = m.commands.commands[cmd] try: + parsed, _ = command_manager.parse_partial(binding.command.strip()) + + cmd = parsed[0].value + args = [ + a.value for a in parsed[1:] + if a.type != mitmproxy.types.Space + ] + + assert cmd in m.commands.commands + + cmd_obj = m.commands.commands[cmd] cmd_obj.prepare_args(args) except Exception as e: - raise ValueError(f"Invalid command: {binding.command}") from e + raise ValueError(f"Invalid binding: {binding.command}") from e