ressurect killed, is_text -> type

This commit is contained in:
Maximilian Hils 2021-03-10 20:23:25 +01:00
parent e1f938f05b
commit d8aeef1bfd
8 changed files with 58 additions and 62 deletions

View File

@ -284,22 +284,21 @@ class Dumper:
) )
def websocket_message(self, f: http.HTTPFlow): def websocket_message(self, f: http.HTTPFlow):
assert f.websocket is not None assert f.websocket is not None # satisfy type checker
if self.match(f): if self.match(f):
message = f.websocket.messages[-1] message = f.websocket.messages[-1]
direction = "->" if message.from_client else "<-" direction = "->" if message.from_client else "<-"
typ = "text" if message.is_text else "binary"
self.echo( self.echo(
f"{human.format_address(f.client_conn.peername)} " f"{human.format_address(f.client_conn.peername)} "
f"{direction} WebSocket {typ} message " f"{direction} WebSocket {message.type.name.lower()} message "
f"{direction} {human.format_address(f.server_conn.address)}{f.request.path}" f"{direction} {human.format_address(f.server_conn.address)}{f.request.path}"
) )
if ctx.options.flow_detail >= 3: if ctx.options.flow_detail >= 3:
self._echo_message(message, f) self._echo_message(message, f)
def websocket_end(self, f: http.HTTPFlow): def websocket_end(self, f: http.HTTPFlow):
assert f.websocket is not None assert f.websocket is not None # satisfy type checker
if self.match(f): if self.match(f):
c = 'client' if f.websocket.close_by_client else 'server' c = 'client' if f.websocket.close_by_client else 'server'
self.echo(f"WebSocket connection closed by {c}: {f.websocket.close_code} {f.websocket.close_reason}") self.echo(f"WebSocket connection closed by {c}: {f.websocket.close_code} {f.websocket.close_reason}")

View File

@ -102,6 +102,8 @@ class Save:
self.active_flows.add(flow) self.active_flows.add(flow)
def response(self, flow: http.HTTPFlow): def response(self, flow: http.HTTPFlow):
# websocket flows will receive either websocket_end or websocket_error,
# we don't want to persist them here already
if self.stream and flow.websocket is None: if self.stream and flow.websocket is None:
self.stream.add(flow) self.stream.add(flow)
self.active_flows.discard(flow) self.active_flows.discard(flow)

View File

@ -291,12 +291,7 @@ def convert_11_12(data):
"and may appear duplicated." "and may appear duplicated."
) )
data["websocket"] = { data["websocket"] = {
"messages": [ "messages": ws_flow["messages"],
# old: int(self.type), self.from_client, self.content, self.timestamp, self.killed
# new: self.from_client, self.is_text, self.content, self.timestamp
[from_client, typ == 0x1, strutils.always_bytes(content) if not killed else b"", timestamp]
for typ, from_client, content, timestamp, killed in ws_flow["messages"]
],
"close_by_client": ws_flow["close_sender"] == "client", "close_by_client": ws_flow["close_sender"] == "client",
"close_code": ws_flow["close_code"], "close_code": ws_flow["close_code"],
"close_reason": ws_flow["close_reason"], "close_reason": ws_flow["close_reason"],

View File

@ -11,7 +11,7 @@ from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.context import Context from mitmproxy.proxy.context import Context
from mitmproxy.proxy.utils import expect from mitmproxy.proxy.utils import expect
from wsproto import ConnectionState from wsproto import ConnectionState
from wsproto.frame_protocol import CloseReason from wsproto.frame_protocol import CloseReason, Opcode
@dataclass @dataclass
@ -96,7 +96,7 @@ class WebsocketLayer(layer.Layer):
server_extensions = [] server_extensions = []
# Parse extension headers. We only support deflate at the moment and ignore everything else. # Parse extension headers. We only support deflate at the moment and ignore everything else.
assert self.flow.response assert self.flow.response # satisfy type checker
ext_header = self.flow.response.headers.get("Sec-WebSocket-Extensions", "") ext_header = self.flow.response.headers.get("Sec-WebSocket-Extensions", "")
if ext_header: if ext_header:
for ext in wsproto.utilities.split_comma_header(ext_header.encode("ascii", "replace")): for ext in wsproto.utilities.split_comma_header(ext_header.encode("ascii", "replace")):
@ -122,7 +122,7 @@ class WebsocketLayer(layer.Layer):
@expect(events.DataReceived, events.ConnectionClosed) @expect(events.DataReceived, events.ConnectionClosed)
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]: def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
assert self.flow.websocket assert self.flow.websocket # satisfy type checker
from_client = event.connection == self.context.client from_client = event.connection == self.context.client
from_str = 'client' if from_client else 'server' from_str = 'client' if from_client else 'server'
@ -144,8 +144,10 @@ class WebsocketLayer(layer.Layer):
if isinstance(ws_event, wsproto.events.Message): if isinstance(ws_event, wsproto.events.Message):
is_text = isinstance(ws_event.data, str) is_text = isinstance(ws_event.data, str)
if is_text: if is_text:
typ = Opcode.TEXT
src_ws.frame_buf.append(ws_event.data.encode()) src_ws.frame_buf.append(ws_event.data.encode())
else: else:
typ = Opcode.BINARY
src_ws.frame_buf.append(ws_event.data) src_ws.frame_buf.append(ws_event.data)
if ws_event.message_finished: if ws_event.message_finished:
@ -154,12 +156,13 @@ class WebsocketLayer(layer.Layer):
fragmentizer = Fragmentizer(src_ws.frame_buf, is_text) fragmentizer = Fragmentizer(src_ws.frame_buf, is_text)
src_ws.frame_buf.clear() src_ws.frame_buf.clear()
message = websocket.WebSocketMessage(from_client, is_text, content) message = websocket.WebSocketMessage(typ, from_client, content)
self.flow.websocket.messages.append(message) self.flow.websocket.messages.append(message)
yield WebsocketMessageHook(self.flow) yield WebsocketMessageHook(self.flow)
for msg in fragmentizer(message.content): if not message.killed:
yield dst_ws.send2(msg) for msg in fragmentizer(message.content):
yield dst_ws.send2(msg)
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)): elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
yield commands.Log( yield commands.Log(

View File

@ -7,6 +7,7 @@ from mitmproxy import http
from mitmproxy import tcp from mitmproxy import tcp
from mitmproxy import websocket from mitmproxy import websocket
from mitmproxy.test import tutils from mitmproxy.test import tutils
from wsproto.frame_protocol import Opcode
def ttcpflow(client_conn=True, server_conn=True, messages=True, err=None): def ttcpflow(client_conn=True, server_conn=True, messages=True, err=None):
@ -69,9 +70,9 @@ def twebsocketflow(messages=True, err=None) -> http.HTTPFlow:
if messages is True: if messages is True:
flow.websocket.messages = [ flow.websocket.messages = [
websocket.WebSocketMessage(True, False, b"hello binary", 946681203), websocket.WebSocketMessage(Opcode.BINARY, True, b"hello binary", 946681203),
websocket.WebSocketMessage(True, True, b"hello text", 946681204), websocket.WebSocketMessage(Opcode.TEXT, True, b"hello text", 946681204),
websocket.WebSocketMessage(False, True, b"it's me", 946681205), websocket.WebSocketMessage(Opcode.TEXT, False, b"it's me", 946681205),
] ]
if err is True: if err is True:
flow.error = terr() flow.error = terr()

View File

@ -6,12 +6,14 @@ as HTTP flows as well. They can be distinguished from regular HTTP requests by h
This module only defines the classes for individual `WebSocketMessage`s and the `WebSocketData` container. This module only defines the classes for individual `WebSocketMessage`s and the `WebSocketData` container.
""" """
import time import time
import warnings from typing import List, Tuple, Union
from typing import List
from typing import Optional from typing import Optional
from mitmproxy import stateobject from mitmproxy import stateobject
from mitmproxy.coretypes import serializable from mitmproxy.coretypes import serializable
from wsproto.frame_protocol import Opcode
WebSocketMessageState = Tuple[int, bool, bytes, float, bool]
class WebSocketMessage(serializable.Serializable): class WebSocketMessage(serializable.Serializable):
@ -25,75 +27,63 @@ class WebSocketMessage(serializable.Serializable):
text and binary messages. To avoid a whole class of nasty type confusion bugs, text and binary messages. To avoid a whole class of nasty type confusion bugs,
mitmproxy stores all message contents as binary. If you need text, you can decode the `content` property: mitmproxy stores all message contents as binary. If you need text, you can decode the `content` property:
>>> if message.is_text: >>> from wsproto.frame_protocol import Opcode
>>> if message.type == Opcode.TEXT:
>>> text = message.content.decode() >>> text = message.content.decode()
Per the WebSocket spec, text messages always use UTF-8 encoding.
""" """
from_client: bool from_client: bool
"""True if this messages was sent by the client.""" """True if this messages was sent by the client."""
is_text: bool type: Opcode
""" """
True if the message is a text message, False if the message is a binary message. The message type, as per RFC 6455's [opcode](https://tools.ietf.org/html/rfc6455#section-5.2).
In either case, mitmproxy will store the message contents as *bytes*. Note that mitmproxy will always store the message contents as *bytes*.
A dedicated `.text` property for text messages is planned, see https://github.com/mitmproxy/mitmproxy/pull/4486.
""" """
content: bytes content: bytes
"""A byte-string representing the content of this message.""" """A byte-string representing the content of this message."""
timestamp: float timestamp: float
"""Timestamp of when this message was received or created.""" """Timestamp of when this message was received or created."""
killed: bool
"""True if the message has not been forwarded by mitmproxy, False otherwise."""
def __init__( def __init__(
self, self,
type: Union[int, Opcode],
from_client: bool, from_client: bool,
is_text: bool,
content: bytes, content: bytes,
timestamp: Optional[float] = None, timestamp: Optional[float] = None,
killed: bool = False,
) -> None: ) -> None:
self.from_client = from_client self.from_client = from_client
self.is_text = is_text self.type = Opcode(type)
self.content = content self.content = content
self.timestamp: float = timestamp or time.time() self.timestamp: float = timestamp or time.time()
self.killed = killed
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state: WebSocketMessageState):
return cls(*state) return cls(*state)
def get_state(self): def get_state(self) -> WebSocketMessageState:
return self.from_client, self.is_text, self.content, self.timestamp return int(self.type), self.from_client, self.content, self.timestamp, self.killed
def set_state(self, state): def set_state(self, state: WebSocketMessageState) -> None:
self.from_client, self.is_text, self.content, self.timestamp = state typ, self.from_client, self.content, self.timestamp, self.killed = state
self.type = Opcode(typ)
def __repr__(self): def __repr__(self):
if self.is_text: if self.type == Opcode.TEXT:
return repr(self.content.decode(errors="replace")) return repr(self.content.decode(errors="replace"))
else: else:
return repr(self.content) return repr(self.content)
def kill(self): # pragma: no cover def kill(self):
""" # Likely to be replaced with .drop() in the future, see https://github.com/mitmproxy/mitmproxy/pull/4486
Kill this message. self.killed = True
It will not be sent to the other endpoint.
"""
warnings.warn(
"WebSocketMessage.kill is deprecated, set an empty content instead.",
DeprecationWarning,
stacklevel=2,
)
self.content = b""
@property
def killed(self) -> bool: # pragma: no cover
"""
True if this messages was killed and should not be sent to the other endpoint.
"""
warnings.warn(
"WebSocketMessage.killed is deprecated, check for an empty content instead.",
DeprecationWarning,
stacklevel=2,
)
return bool(self.content)
class WebSocketData(stateobject.StateObject): class WebSocketData(stateobject.StateObject):

View File

@ -13,6 +13,7 @@ from mitmproxy.proxy.events import DataReceived, ConnectionClosed
from mitmproxy.proxy.layers import http, websocket from mitmproxy.proxy.layers import http, websocket
from mitmproxy.websocket import WebSocketData from mitmproxy.websocket import WebSocketData
from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply
from wsproto.frame_protocol import Opcode
@dataclass @dataclass
@ -97,10 +98,10 @@ def test_upgrade(tctx):
assert len(flow().websocket.messages) == 2 assert len(flow().websocket.messages) == 2
assert flow().websocket.messages[0].content == b"hello world" assert flow().websocket.messages[0].content == b"hello world"
assert flow().websocket.messages[0].from_client assert flow().websocket.messages[0].from_client
assert flow().websocket.messages[0].is_text assert flow().websocket.messages[0].type == Opcode.TEXT
assert flow().websocket.messages[1].content == b"hello back" assert flow().websocket.messages[1].content == b"hello back"
assert flow().websocket.messages[1].from_client is False assert flow().websocket.messages[1].from_client is False
assert flow().websocket.messages[1].is_text is False assert flow().websocket.messages[1].type == Opcode.BINARY
@pytest.fixture() @pytest.fixture()
@ -150,7 +151,7 @@ def test_drop_message(ws_testdata):
>> DataReceived(tctx.server, b"\x81\x03foo") >> DataReceived(tctx.server, b"\x81\x03foo")
<< websocket.WebsocketMessageHook(flow) << websocket.WebsocketMessageHook(flow)
) )
flow.websocket.messages[-1].content = "" flow.websocket.messages[-1].kill()
assert ( assert (
playbook playbook
>> reply() >> reply()

View File

@ -1,6 +1,7 @@
from mitmproxy import http from mitmproxy import http
from mitmproxy import websocket from mitmproxy import websocket
from mitmproxy.test import tflow from mitmproxy.test import tflow
from wsproto.frame_protocol import Opcode
class TestWebSocketData: class TestWebSocketData:
@ -15,9 +16,13 @@ class TestWebSocketData:
class TestWebSocketMessage: class TestWebSocketMessage:
def test_basic(self): def test_basic(self):
m = websocket.WebSocketMessage(True, True, b"foo") m = websocket.WebSocketMessage(Opcode.TEXT, True, b"foo")
m.set_state(m.get_state()) m.set_state(m.get_state())
assert m.content == b"foo" assert m.content == b"foo"
assert repr(m) == "'foo'" assert repr(m) == "'foo'"
m.is_text = False m.type = Opcode.BINARY
assert repr(m) == "b'foo'" assert repr(m) == "b'foo'"
assert not m.killed
m.kill()
assert m.killed