Merge pull request #4650 from mhils/prinzhorn

[WIP] Fix WebSocket/TCP injection
This commit is contained in:
Maximilian Hils 2021-07-15 13:18:33 +02:00 committed by GitHub
commit 5b4ac96f4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 179 additions and 95 deletions

View File

@ -55,6 +55,10 @@ To document all event hooks, we do a bit of hackery:
{% if doc.qualname.startswith("ServerConnectionHookData") and doc.name != "__init__" %}
{{ default_is_public(doc) }}
{% endif %}
{% elif doc.modulename == "mitmproxy.websocket" %}
{% if doc.qualname != "WebSocketMessage.type" %}
{{ default_is_public(doc) }}
{% endif %}
{% else %}
{{ default_is_public(doc) }}
{% endif %}

View File

@ -8,7 +8,7 @@ from mitmproxy.proxy import commands, events, server_hooks
from mitmproxy.proxy import server
from mitmproxy.proxy.layers.tcp import TcpMessageInjected
from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected
from mitmproxy.utils import asyncio_utils, human, strutils
from mitmproxy.utils import asyncio_utils, human
from wsproto.frame_protocol import Opcode
@ -190,15 +190,14 @@ class Proxyserver:
self._connections[event.flow.client_conn.peername].server_event(event)
@command.command("inject.websocket")
def inject_websocket(self, flow: Flow, to_client: bool, message: str, is_text: bool = True):
def inject_websocket(self, flow: Flow, to_client: bool, message: bytes, is_text: bool = True):
if not isinstance(flow, http.HTTPFlow) or not flow.websocket:
ctx.log.warn("Cannot inject WebSocket messages into non-WebSocket flows.")
message_bytes = strutils.escaped_str_to_bytes(message)
msg = websocket.WebSocketMessage(
Opcode.TEXT if is_text else Opcode.BINARY,
not to_client,
message_bytes
message
)
event = WebSocketMessageInjected(flow, msg)
try:
@ -207,12 +206,11 @@ class Proxyserver:
ctx.log.warn(str(e))
@command.command("inject.tcp")
def inject_tcp(self, flow: Flow, to_client: bool, message: str):
def inject_tcp(self, flow: Flow, to_client: bool, message: bytes):
if not isinstance(flow, tcp.TCPFlow):
ctx.log.warn("Cannot inject TCP messages into non-TCP flows.")
message_bytes = strutils.escaped_str_to_bytes(message)
event = TcpMessageInjected(flow, tcp.TCPMessage(not to_client, message_bytes))
event = TcpMessageInjected(flow, tcp.TCPMessage(not to_client, message))
try:
self.inject_event(event)
except ValueError as e:

View File

@ -73,7 +73,7 @@ class Command:
for name, parameter in self.signature.parameters.items():
t = parameter.annotation
if not mitmproxy.types.CommandTypes.get(parameter.annotation, None):
raise exceptions.CommandError(f"Argument {name} has an unknown type ({_empty_as_none(t)}) in {func}.")
raise exceptions.CommandError(f"Argument {name} has an unknown type {t} in {func}.")
if self.return_type and not mitmproxy.types.CommandTypes.get(self.return_type, None):
raise exceptions.CommandError(f"Return type has an unknown type ({self.return_type}) in {func}.")
@ -106,8 +106,15 @@ class Command:
raise exceptions.CommandError(f"Command argument mismatch: \n {expected}\n {received}")
for name, value in bound_arguments.arguments.items():
convert_to = self.signature.parameters[name].annotation
bound_arguments.arguments[name] = parsearg(self.manager, value, convert_to)
param = self.signature.parameters[name]
convert_to = param.annotation
if param.kind == param.VAR_POSITIONAL:
bound_arguments.arguments[name] = tuple(
parsearg(self.manager, x, convert_to)
for x in value
)
else:
bound_arguments.arguments[name] = parsearg(self.manager, value, convert_to)
bound_arguments.apply_defaults()

View File

@ -1,4 +1,3 @@
import ast
import re
import pyparsing
@ -10,13 +9,9 @@ import pyparsing
PartialQuotedString = pyparsing.Regex(
re.compile(
r'''
(["']) # start quote
(?:
(?:\\.) # escape sequence
|
(?!\1). # unescaped character that is not our quote nor the begin of an escape sequence. We can't use \1 in []
)*
(?:\1|$) # end quote
"[^"]*(?:"|$) # double-quoted string that ends with double quote or EOF
|
'[^']*(?:'|$) # single-quoted string that ends with single quote or EOF
''',
re.VERBOSE
)
@ -32,18 +27,15 @@ expr = pyparsing.ZeroOrMore(
def quote(val: str) -> str:
if val and all(char not in val for char in "'\" \r\n\t"):
return val
return repr(val) # TODO: More of a hack.
if '"' not in val:
return f'"{val}"'
if "'" not in val:
return f"'{val}'"
return '"' + val.replace('"', r"\x22") + '"'
def unquote(x: str) -> str:
quoted = (
(x.startswith('"') and x.endswith('"'))
or
(x.startswith("'") and x.endswith("'"))
)
if quoted:
try:
x = ast.literal_eval(x)
except Exception:
x = x[1:-1]
return x
if len(x) > 1 and x[0] in "'\"" and x[0] == x[-1]:
return x[1:-1]
else:
return x

View File

@ -34,8 +34,8 @@ class Display(base.Cell):
class Edit(base.Cell):
def __init__(self, data: bytes) -> None:
data = strutils.bytes_to_escaped_str(data)
w = urwid.Edit(edit_text=data, wrap="any", multiline=True)
d = strutils.bytes_to_escaped_str(data)
w = urwid.Edit(edit_text=d, wrap="any", multiline=True)
w = urwid.AttrWrap(w, "editfield")
super().__init__(w)

View File

@ -1,10 +1,12 @@
import codecs
import os
import glob
import re
import typing
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.utils import emoji
from mitmproxy.utils import emoji, strutils
if typing.TYPE_CHECKING: # pragma: no cover
from mitmproxy.command import CommandManager
@ -104,16 +106,52 @@ class _StrType(_BaseType):
typ = str
display = "str"
# https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals
escape_sequences = re.compile(r"""
\\ (
[\\'"abfnrtv] # Standard C escape sequence
| [0-7]{1,3} # Character with octal value
| x.. # Character with hex value
| N{[^}]+} # Character name in the Unicode database
| u.... # Character with 16-bit hex value
| U........ # Character with 32-bit hex value
)
""", re.VERBOSE)
@staticmethod
def _unescape(match: re.Match) -> str:
"""Decode the full text of one escape-sequence match with the "unicode-escape" codec and return it."""
return codecs.decode(match.group(0), "unicode-escape") # type: ignore
def completion(self, manager: "CommandManager", t: type, s: str) -> typing.Sequence[str]:
return []
def parse(self, manager: "CommandManager", t: type, s: str) -> str:
return s
try:
return self.escape_sequences.sub(self._unescape, s)
except ValueError as e:
raise exceptions.TypeError(f"Invalid str: {e}") from e
def is_valid(self, manager: "CommandManager", typ: typing.Any, val: typing.Any) -> bool:
return isinstance(val, str)
class _BytesType(_BaseType):
    """
    Command argument type for ``bytes`` values.

    The string typed by the user is converted with
    ``strutils.escaped_str_to_bytes``.
    """
    typ = bytes
    display = "bytes"

    def completion(self, manager: "CommandManager", t: type, s: str) -> typing.Sequence[str]:
        """Return an empty list: no completion candidates are offered for bytes."""
        return []

    def parse(self, manager: "CommandManager", t: type, s: str) -> bytes:
        """
        Convert the escaped string *s* into bytes.

        Raises:
            exceptions.TypeError: if ``strutils.escaped_str_to_bytes``
                raises ``ValueError`` for *s*.
        """
        try:
            return strutils.escaped_str_to_bytes(s)
        except ValueError as e:
            # Chain explicitly so the underlying ValueError is recorded as the
            # cause, matching the error handling in _StrType.parse.
            raise exceptions.TypeError(str(e)) from e

    def is_valid(self, manager: "CommandManager", typ: typing.Any, val: typing.Any) -> bool:
        """Return whether *val* is a ``bytes`` instance."""
        return isinstance(val, bytes)
class _UnknownType(_BaseType):
typ = Unknown
display = "unknown"
@ -460,4 +498,5 @@ CommandTypes = TypeManager(
_PathType,
_StrType,
_StrSeqType,
_BytesType,
)

View File

@ -79,7 +79,7 @@ def escape_control_characters(text: str, keep_spacing=True) -> str:
return text.translate(trans)
def bytes_to_escaped_str(data, keep_spacing=False, escape_single_quotes=False):
def bytes_to_escaped_str(data: bytes, keep_spacing: bool = False, escape_single_quotes: bool = False) -> str:
"""
Take bytes and return a safe string that can be displayed to the user.
@ -107,7 +107,7 @@ def bytes_to_escaped_str(data, keep_spacing=False, escape_single_quotes=False):
return ret
def escaped_str_to_bytes(data):
def escaped_str_to_bytes(data: str) -> bytes:
"""
Take an escaped string and return the unescaped bytes equivalent.
@ -119,7 +119,7 @@ def escaped_str_to_bytes(data):
# This one is difficult - we use an undocumented Python API here
# as per http://stackoverflow.com/a/23151714/934719
return codecs.escape_decode(data)[0]
return codecs.escape_decode(data)[0] # type: ignore
def is_mostly_bin(s: bytes) -> bool:

View File

@ -20,18 +20,16 @@ class WebSocketMessage(serializable.Serializable):
"""
A single WebSocket message sent from one peer to the other.
Fragmented WebSocket messages are reassembled by mitmproxy and the
Fragmented WebSocket messages are reassembled by mitmproxy and then
represented as a single instance of this class.
The [WebSocket RFC](https://tools.ietf.org/html/rfc6455) specifies both
text and binary messages. To avoid a whole class of nasty type confusion bugs,
mitmproxy stores all message contents as binary. If you need text, you can decode the `content` property:
mitmproxy stores all message contents as `bytes`. If you need a `str`, you can access the `text` property
on text messages:
>>> from wsproto.frame_protocol import Opcode
>>> if message.type == Opcode.TEXT:
>>> text = message.content.decode()
Per the WebSocket spec, text messages always use UTF-8 encoding.
>>> if message.is_text:
>>> text = message.text
"""
from_client: bool
@ -40,8 +38,7 @@ class WebSocketMessage(serializable.Serializable):
"""
The message type, as per RFC 6455's [opcode](https://tools.ietf.org/html/rfc6455#section-5.2).
Note that mitmproxy will always store the message contents as *bytes*.
A dedicated `.text` property for text messages is planned, see https://github.com/mitmproxy/mitmproxy/pull/4486.
Mitmproxy currently only exposes messages assembled from `TEXT` and `BINARY` frames.
"""
content: bytes
"""A byte-string representing the content of this message."""
@ -81,10 +78,39 @@ class WebSocketMessage(serializable.Serializable):
else:
return repr(self.content)
@property
def is_text(self) -> bool:
"""
`True` if this message is assembled from WebSocket `TEXT` frames,
`False` if it is assembled from `BINARY` frames.

The value is derived by comparing `WebSocketMessage.type` with `Opcode.TEXT`.
"""
return self.type == Opcode.TEXT
def kill(self) -> None:
"""Mark this message as killed by setting its `killed` attribute to `True`."""
# Likely to be replaced with .drop() in the future, see https://github.com/mitmproxy/mitmproxy/pull/4486
self.killed = True
@property
def text(self) -> str:
    """
    The message content decoded to a `str`.

    This attribute is only available if `WebSocketMessage.is_text` is `True`;
    for any other message type an `AttributeError` is raised.

    *See also:* `WebSocketMessage.content`
    """
    if self.type == Opcode.TEXT:
        return self.content.decode()
    frame_kind = self.type.name.title()
    raise AttributeError(f"{frame_kind} WebSocket frames do not have a 'text' attribute.")


@text.setter
def text(self, value: str) -> None:
    # Writing is subject to the same restriction as reading: text messages only.
    if self.type == Opcode.TEXT:
        self.content = value.encode()
        return
    frame_kind = self.type.name.title()
    raise AttributeError(f"{frame_kind} WebSocket frames do not have a 'text' attribute.")
class WebSocketData(stateobject.StateObject):
"""
@ -97,9 +123,9 @@ class WebSocketData(stateobject.StateObject):
closed_by_client: Optional[bool] = None
"""
True if the client closed the connection,
False if the server closed the connection,
None if the connection is active.
`True` if the client closed the connection,
`False` if the server closed the connection,
`None` if the connection is active.
"""
close_code: Optional[int] = None
"""[Close Code](https://tools.ietf.org/html/rfc6455#section-7.1.5)"""

View File

@ -90,7 +90,7 @@ async def test_start_stop():
@pytest.mark.asyncio
async def test_inject():
async def test_inject() -> None:
async def server_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
while s := await reader.read(1):
writer.write(s.upper())
@ -112,39 +112,39 @@ async def test_inject():
writer.write(b"a")
assert await reader.read(1) == b"A"
ps.inject_tcp(state.flows[0], False, "b")
ps.inject_tcp(state.flows[0], False, b"b")
assert await reader.read(1) == b"B"
ps.inject_tcp(state.flows[0], True, "c")
ps.inject_tcp(state.flows[0], True, b"c")
assert await reader.read(1) == b"c"
@pytest.mark.asyncio
async def test_inject_fail():
async def test_inject_fail() -> None:
ps = Proxyserver()
with taddons.context(ps) as tctx:
ps.inject_websocket(
tflow.tflow(),
True,
"test"
b"test"
)
await tctx.master.await_log("Cannot inject WebSocket messages into non-WebSocket flows.", level="warn")
ps.inject_tcp(
tflow.tflow(),
True,
"test"
b"test"
)
await tctx.master.await_log("Cannot inject TCP messages into non-TCP flows.", level="warn")
ps.inject_websocket(
tflow.twebsocketflow(),
True,
"test"
b"test"
)
await tctx.master.await_log("Flow is not from a live connection.", level="warn")
ps.inject_websocket(
tflow.ttcpflow(),
True,
"test"
b"test"
)
await tctx.master.await_log("Flow is not from a live connection.", level="warn")

View File

@ -367,24 +367,6 @@ class TestCommand:
],
[],
],
[
r'cmd13 "a \"b\" c"',
[
command.ParseResult(value="cmd13", type=mitmproxy.types.Cmd, valid=False),
command.ParseResult(value=" ", type=mitmproxy.types.Space, valid=True),
command.ParseResult(value=r'"a \"b\" c"', type=mitmproxy.types.Unknown, valid=False),
],
[],
],
[
r"cmd14 'a \'b\' c'",
[
command.ParseResult(value="cmd14", type=mitmproxy.types.Cmd, valid=False),
command.ParseResult(value=" ", type=mitmproxy.types.Space, valid=True),
command.ParseResult(value=r"'a \'b\' c'", type=mitmproxy.types.Unknown, valid=False),
],
[],
],
[
" spaces_at_the_begining_are_not_stripped",
[
@ -436,12 +418,6 @@ def test_simple():
c.call("nonexistent")
with pytest.raises(exceptions.CommandError, match="Unknown"):
c.execute("\\")
with pytest.raises(exceptions.CommandError, match="Unknown"):
c.execute(r"\'")
with pytest.raises(exceptions.CommandError, match="Unknown"):
c.execute(r"\"")
with pytest.raises(exceptions.CommandError, match="Unknown"):
c.execute(r"\"")
c.add("empty", a.empty)
c.execute("empty")

View File

@ -11,7 +11,6 @@ from mitmproxy import command_lexer
("'foo'", True),
('"foo"', True),
("'foo' bar'", False),
("'foo\\' bar'", True),
("'foo' 'bar'", False),
("'foo'x", False),
('''"foo ''', True),
@ -43,8 +42,19 @@ def test_expr(test_input, expected):
@given(text())
@example(r"foo")
@example(r"'foo\''")
@example(r"'foo\"'")
@example(r'"foo\""')
@example(r'"foo\'"')
@example("'foo\\'")
@example("'foo\\\\'")
@example("\"foo\\'\"")
@example("\"foo\\\\'\"")
@example('\'foo\\"\'')
@example(r"\\\foo")
def test_quote_unquote_cycle(s):
assert command_lexer.unquote(command_lexer.quote(s)) == s
assert command_lexer.unquote(command_lexer.quote(s)).replace(r"\x22", '"') == s
@given(text())

View File

@ -40,6 +40,21 @@ def test_str():
assert b.is_valid(tctx.master.commands, str, 1) is False
assert b.completion(tctx.master.commands, str, "") == []
assert b.parse(tctx.master.commands, str, "foo") == "foo"
assert b.parse(tctx.master.commands, str, r"foo\nbar") == "foo\nbar"
assert b.parse(tctx.master.commands, str, r"\N{BELL}") == "🔔"
with pytest.raises(mitmproxy.exceptions.TypeError):
b.parse(tctx.master.commands, bool, r"\N{UNKNOWN UNICODE SYMBOL!}")
def test_bytes() -> None:
"""Check `_BytesType` validation, completion and parsing, including a parse failure."""
with taddons.context() as tctx:
b = mitmproxy.types._BytesType()
# Only bytes instances are accepted as valid values.
assert b.is_valid(tctx.master.commands, bytes, b"foo") is True
assert b.is_valid(tctx.master.commands, bytes, 1) is False
assert b.completion(tctx.master.commands, bytes, "") == []
assert b.parse(tctx.master.commands, bytes, "foo") == b"foo"
# Input ending in a lone backslash must be reported as a TypeError.
with pytest.raises(mitmproxy.exceptions.TypeError):
b.parse(tctx.master.commands, bytes, "incomplete escape sequence\\")
def test_unknown():

View File

@ -1,3 +1,5 @@
import pytest
from mitmproxy import http
from mitmproxy import websocket
from mitmproxy.test import tflow
@ -26,3 +28,18 @@ class TestWebSocketMessage:
assert not m.killed
m.kill()
assert m.killed
def test_text(self):
"""Check `is_text` and the `text` getter/setter on TEXT and BINARY messages."""
txt = websocket.WebSocketMessage(Opcode.TEXT, True, b"foo")
bin = websocket.WebSocketMessage(Opcode.BINARY, True, b"foo")
# TEXT message: `text` reads the content as str and writing it updates `content`.
assert txt.is_text
assert txt.text == "foo"
txt.text = "bar"
assert txt.content == b"bar"
# BINARY message: both reading and writing `text` raise AttributeError.
assert not bin.is_text
with pytest.raises(AttributeError, match="do not have a 'text' attribute."):
_ = bin.text
with pytest.raises(AttributeError, match="do not have a 'text' attribute."):
bin.text = "bar"

View File

@ -20,18 +20,18 @@ async def test_commands_exist():
await m.load_flow(tflow())
for binding in km.bindings:
parsed, _ = command_manager.parse_partial(binding.command.strip())
cmd = parsed[0].value
args = [
a.value for a in parsed[1:]
if a.type != mitmproxy.types.Space
]
assert cmd in m.commands.commands
cmd_obj = m.commands.commands[cmd]
try:
parsed, _ = command_manager.parse_partial(binding.command.strip())
cmd = parsed[0].value
args = [
a.value for a in parsed[1:]
if a.type != mitmproxy.types.Space
]
assert cmd in m.commands.commands
cmd_obj = m.commands.commands[cmd]
cmd_obj.prepare_args(args)
except Exception as e:
raise ValueError(f"Invalid command: {binding.command}") from e
raise ValueError(f"Invalid binding: {binding.command}") from e