[sans-io] websockets: fix bugs, 100% test coverage 🎉

This commit is contained in:
Maximilian Hils 2020-12-09 15:51:40 +01:00
parent e79cc6bc24
commit d32a5d5f33
4 changed files with 356 additions and 28 deletions

View File

@ -297,7 +297,12 @@ class HttpStream(layer.Layer):
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client) yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
if self.flow.response.status_code == 101: if self.flow.response.status_code == 101:
if self.flow.response.headers.get("upgrade", "").strip().lower() == "websocket": is_websocket = (
self.flow.response.headers.get("upgrade", "").lower() == "websocket"
and
self.flow.request.headers.get("Sec-WebSocket-Version", "") == "13"
)
if is_websocket:
self.child_layer = websocket.WebsocketLayer(self.context, self.flow) self.child_layer = websocket.WebsocketLayer(self.context, self.flow)
else: else:
self.child_layer = tcp.TCPLayer(self.context) self.child_layer = tcp.TCPLayer(self.context)

View File

@ -69,6 +69,9 @@ class WebsocketConnection(wsproto.Connection):
data = super().send(event) data = super().send(event)
return commands.SendData(self.conn, data) return commands.SendData(self.conn, data)
def __repr__(self):
return f"WebsocketConnection<{self.state.name}, {self.conn}>"
class WebsocketLayer(layer.Layer): class WebsocketLayer(layer.Layer):
""" """
@ -92,7 +95,7 @@ class WebsocketLayer(layer.Layer):
# Parse extension headers. We only support deflate at the moment and ignore everything else. # Parse extension headers. We only support deflate at the moment and ignore everything else.
ext_header = self.flow.handshake_flow.response.headers.get("Sec-WebSocket-Extensions", "") ext_header = self.flow.handshake_flow.response.headers.get("Sec-WebSocket-Extensions", "")
if ext_header: if ext_header:
for ext in wsproto.utilities.split_comma_header(ext_header): for ext in wsproto.utilities.split_comma_header(ext_header.encode("ascii", "replace")):
ext_name = ext.split(";", 1)[0].strip() ext_name = ext.split(";", 1)[0].strip()
if ext_name == wsproto.extensions.PerMessageDeflate.name: if ext_name == wsproto.extensions.PerMessageDeflate.name:
client_deflate = wsproto.extensions.PerMessageDeflate() client_deflate = wsproto.extensions.PerMessageDeflate()
@ -109,7 +112,7 @@ class WebsocketLayer(layer.Layer):
yield WebsocketStartHook(self.flow) yield WebsocketStartHook(self.flow)
if self.flow.stream: if self.flow.stream: # pragma: no cover
raise NotImplementedError("WebSocket streaming is not supported at the moment.") raise NotImplementedError("WebSocket streaming is not supported at the moment.")
self._handle_event = self.relay_messages self._handle_event = self.relay_messages
@ -130,10 +133,8 @@ class WebsocketLayer(layer.Layer):
if isinstance(event, events.DataReceived): if isinstance(event, events.DataReceived):
src_ws.receive_data(event.data) src_ws.receive_data(event.data)
elif isinstance(event, events.ConnectionClosed): elif isinstance(event, events.ConnectionClosed):
if src_ws.state not in {ConnectionState.OPEN, ConnectionState.LOCAL_CLOSING}:
return
src_ws.receive_data(None) src_ws.receive_data(None)
else: else: # pragma: no cover
raise AssertionError(f"Unexpected event: {event}") raise AssertionError(f"Unexpected event: {event}")
for ws_event in src_ws.events(): for ws_event in src_ws.events():
@ -143,14 +144,10 @@ class WebsocketLayer(layer.Layer):
if ws_event.message_finished: if ws_event.message_finished:
if isinstance(ws_event, wsproto.events.TextMessage): if isinstance(ws_event, wsproto.events.TextMessage):
frame_type = Opcode.TEXT frame_type = Opcode.TEXT
content = "" content = "".join(src_ws.frame_buf)
else: else:
frame_type = Opcode.BINARY frame_type = Opcode.BINARY
content = b"" content = b"".join(src_ws.frame_buf)
try:
content = content.join(src_ws.frame_buf)
except TypeError:
return self.handle_protocol_error(src_ws, "mixed text and binary fragments")
fragmentizer = Fragmentizer(src_ws.frame_buf) fragmentizer = Fragmentizer(src_ws.frame_buf)
src_ws.frame_buf.clear() src_ws.frame_buf.clear()
@ -166,8 +163,8 @@ class WebsocketLayer(layer.Layer):
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)): elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
yield commands.Log( yield commands.Log(
f"Received WebSocket {event.__class__.__name__.lower()} from {from_str} " f"Received WebSocket {ws_event.__class__.__name__.lower()} from {from_str} "
f"(payload: {ws_event.payload!r})" f"(payload: {bytes(ws_event.payload)!r})"
) )
yield dst_ws.send(ws_event) yield dst_ws.send(ws_event)
elif isinstance(ws_event, wsproto.events.CloseConnection): elif isinstance(ws_event, wsproto.events.CloseConnection):
@ -175,28 +172,20 @@ class WebsocketLayer(layer.Layer):
self.flow.close_code = ws_event.code self.flow.close_code = ws_event.code
self.flow.close_reason = ws_event.reason self.flow.close_reason = ws_event.reason
for ws in [self.client_ws, self.server_ws]: for ws in [self.server_ws, self.client_ws]:
if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}: if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
# response == original event, so no need to differentiate here. # response == original event, so no need to differentiate here.
yield ws.send(ws_event) yield ws.send(ws_event)
yield commands.CloseConnection(ws.conn) yield commands.CloseConnection(ws.conn)
if ws_event.code in {1000, 1001, 1005}: if ws_event.code in {1000, 1001, 1005}:
yield WebsocketEndHook(self.flow) yield WebsocketEndHook(self.flow)
else: else:
self.flow.error = flow.Error(f"WebSocket Error: {format_close_event(ws_event)}") self.flow.error = flow.Error(f"WebSocket Error: {format_close_event(ws_event)}")
yield WebsocketErrorHook(self.flow) yield WebsocketErrorHook(self.flow)
yield commands.CloseConnection(self.context.client) self._handle_event = self.done
else: else: # pragma: no cover
raise AssertionError(f"Unexpected WebSocket event: {ws_event}") raise AssertionError(f"Unexpected WebSocket event: {ws_event}")
def handle_protocol_error(self, ws: WebsocketConnection, message=None):
self.flow.error = flow.Error(f"WebSocket Error: {human.format_address(ws.conn.peername)} {message}")
yield WebsocketErrorHook(self.flow)
if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
yield ws.send(wsproto.events.CloseConnection(CloseReason.PROTOCOL_ERROR, message))
yield commands.CloseConnection(self.context.client)
self._handle_event = self.done
@expect(events.DataReceived, events.ConnectionClosed) @expect(events.DataReceived, events.ConnectionClosed)
def done(self, _) -> layer.CommandGenerator[None]: def done(self, _) -> layer.CommandGenerator[None]:
yield from () yield from ()
@ -219,7 +208,7 @@ class Fragmentizer:
meaning. An intermediary might coalesce and/or split frames, [...] meaning. An intermediary might coalesce and/or split frames, [...]
Practice: Practice:
Some WebSocket servers reject large payload sizes. ¯\_()_/¯ Some WebSocket servers reject large payload sizes.
As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks. As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks.
""" """

View File

@ -0,0 +1,334 @@
import secrets
from dataclasses import dataclass
import pytest
import wsproto
import wsproto.events
from mitmproxy.http import HTTPFlow
from mitmproxy.net.http import Request, Response
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2.commands import SendData, CloseConnection, Log
from mitmproxy.proxy2.context import Server, ConnectionState
from mitmproxy.proxy2.events import DataReceived, ConnectionClosed
from mitmproxy.proxy2.layers import http, websocket
from mitmproxy.websocket import WebSocketFlow
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply
@dataclass
class _Masked:
unmasked: bytes
def __eq__(self, other):
other = bytearray(other)
assert other[1] & 0b1000_0000 # assert this is actually masked
other[1] &= 0b0111_1111 # remove mask bit
assert other[1] < 126 # (we don't support extended payload length here)
mask = other[2:6]
payload = bytes([x ^ mask[i % 4] for i, x in enumerate(other[6:])])
return self.unmasked == other[:2] + payload
# noinspection PyTypeChecker
def masked(unmasked: bytes) -> bytes:
return _Masked(unmasked) # type: ignore
def masked_bytes(unmasked: bytes) -> bytes:
header = bytearray(unmasked[:2])
assert header[1] < 126 # assert that this is neither masked nor extended payload
header[1] |= 0b1000_0000
mask = secrets.token_bytes(4)
masked = bytes([x ^ mask[i % 4] for i, x in enumerate(unmasked[2:])])
return bytes(header + mask + masked)
def test_masking():
m = masked(b"\x02\x03foo")
assert m == b"\x02\x83\x1c\x96\xd4\rz\xf9\xbb"
assert m == masked_bytes(b"\x02\x03foo")
def test_upgrade(tctx):
"""Test a HTTP -> WebSocket upgrade"""
tctx.server.address = ("example.com", 80)
tctx.server.state = ConnectionState.OPEN
http_flow = Placeholder(HTTPFlow)
flow = Placeholder(WebSocketFlow)
assert (
Playbook(http.HttpLayer(tctx, HTTPMode.transparent))
>> DataReceived(tctx.client,
b"GET / HTTP/1.1\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: websocket\r\n"
b"Sec-WebSocket-Version: 13\r\n"
b"\r\n")
<< http.HttpRequestHeadersHook(http_flow)
>> reply()
<< http.HttpRequestHook(http_flow)
>> reply()
<< SendData(tctx.server, b"GET / HTTP/1.1\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: websocket\r\n"
b"Sec-WebSocket-Version: 13\r\n"
b"\r\n")
>> DataReceived(tctx.server, b"HTTP/1.1 101 Switching Protocols\r\n"
b"Upgrade: websocket\r\n"
b"Connection: Upgrade\r\n"
b"\r\n")
<< http.HttpResponseHeadersHook(http_flow)
>> reply()
<< http.HttpResponseHook(http_flow)
>> reply()
<< SendData(tctx.client, b"HTTP/1.1 101 Switching Protocols\r\n"
b"Upgrade: websocket\r\n"
b"Connection: Upgrade\r\n"
b"\r\n")
<< websocket.WebsocketStartHook(flow)
>> reply()
>> DataReceived(tctx.client, masked_bytes(b"\x81\x0bhello world"))
<< websocket.WebsocketMessageHook(flow)
>> reply()
<< SendData(tctx.server, masked(b"\x81\x0bhello world"))
>> DataReceived(tctx.server, b"\x82\nhello back")
<< websocket.WebsocketMessageHook(flow)
>> reply()
<< SendData(tctx.client, b"\x82\nhello back")
)
assert flow().handshake_flow == http_flow()
assert len(flow().messages) == 2
assert flow().messages[0].content == "hello world"
assert flow().messages[0].from_client
assert flow().messages[1].content == b"hello back"
assert flow().messages[1].from_client is False
@pytest.fixture()
def ws_testdata(tctx):
tctx.server.address = ("example.com", 80)
tctx.server.state = ConnectionState.OPEN
flow = HTTPFlow(
tctx.client,
tctx.server
)
flow.request = Request.make("GET", "http://example.com/", headers={
"Connection": "upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
})
flow.response = Response.make(101, headers={
"Connection": "upgrade",
"Upgrade": "websocket",
})
return tctx, Playbook(websocket.WebsocketLayer(tctx, flow))
def test_modify_message(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
>> DataReceived(tctx.server, b"\x81\x03foo")
<< websocket.WebsocketMessageHook(flow)
)
flow().messages[-1].content = flow().messages[-1].content.replace("foo", "foobar")
assert (
playbook
>> reply()
<< SendData(tctx.client, b"\x81\x06foobar")
)
def test_drop_message(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
>> DataReceived(tctx.server, b"\x81\x03foo")
<< websocket.WebsocketMessageHook(flow)
)
flow().messages[-1].content = ""
assert (
playbook
>> reply()
<< None
)
def test_fragmented(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
>> DataReceived(tctx.server, b"\x01\x03foo")
>> DataReceived(tctx.server, b"\x80\x03bar")
<< websocket.WebsocketMessageHook(flow)
>> reply()
<< SendData(tctx.client, b"\x01\x03foo")
<< SendData(tctx.client, b"\x80\x03bar")
)
assert flow().messages[-1].content == "foobar"
def test_protocol_error(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
>> DataReceived(tctx.server, b"\x01\x03foo")
>> DataReceived(tctx.server, b"\x02\x03bar")
<< SendData(tctx.server, masked(b"\x88/\x03\xeaexpected CONTINUATION, got <Opcode.BINARY: 2>"))
<< CloseConnection(tctx.server)
<< SendData(tctx.client, b"\x88/\x03\xeaexpected CONTINUATION, got <Opcode.BINARY: 2>")
<< CloseConnection(tctx.client)
<< websocket.WebsocketErrorHook(flow)
>> reply()
)
assert not flow().messages
def test_ping(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
>> DataReceived(tctx.client, masked_bytes(b"\x89\x11ping-with-payload"))
<< Log("Received WebSocket ping from client (payload: b'ping-with-payload')")
<< SendData(tctx.server, masked(b"\x89\x11ping-with-payload"))
>> DataReceived(tctx.server, b"\x8a\x11pong-with-payload")
<< Log("Received WebSocket pong from server (payload: b'pong-with-payload')")
<< SendData(tctx.client, b"\x8a\x11pong-with-payload")
)
assert not flow().messages
def test_close_normal(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
masked_close = Placeholder(bytes)
close = Placeholder(bytes)
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
>> DataReceived(tctx.client, masked_bytes(b"\x88\x00"))
<< SendData(tctx.server, masked_close)
<< CloseConnection(tctx.server)
<< SendData(tctx.client, close)
<< CloseConnection(tctx.client)
<< websocket.WebsocketEndHook(flow)
>> reply()
)
# wsproto currently handles this inconsistently, see
# https://github.com/python-hyper/wsproto/pull/153/files
assert masked_close() == masked(b"\x88\x02\x03\xe8") or masked_close() == masked(b"\x88\x00")
assert close() == b"\x88\x02\x03\xe8" or close() == b"\x88\x00"
assert flow().close_code == 1005
def test_close_disconnect(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
>> ConnectionClosed(tctx.server)
<< CloseConnection(tctx.server)
<< SendData(tctx.client, b"\x88\x02\x03\xe8")
<< CloseConnection(tctx.client)
<< websocket.WebsocketErrorHook(flow)
>> reply()
>> ConnectionClosed(tctx.client)
)
assert "ABNORMAL_CLOSURE" in flow().error.msg
def test_close_error(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
>> DataReceived(tctx.server, b"\x88\x02\x0f\xa0")
<< SendData(tctx.server, masked(b"\x88\x02\x0f\xa0"))
<< CloseConnection(tctx.server)
<< SendData(tctx.client, b"\x88\x02\x0f\xa0")
<< CloseConnection(tctx.client)
<< websocket.WebsocketErrorHook(flow)
>> reply()
)
assert "UNKNOWN_ERROR=4000" in flow().error.msg
def test_deflate(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
# noinspection PyUnresolvedReferences
http_flow: HTTPFlow = playbook.layer.flow.handshake_flow
http_flow.response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate; server_max_window_bits=10"
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
# https://tools.ietf.org/html/rfc7692#section-7.2.3.1
>> DataReceived(tctx.server, bytes.fromhex("c1 07 f2 48 cd c9 c9 07 00"))
<< websocket.WebsocketMessageHook(flow)
>> reply()
<< SendData(tctx.client, bytes.fromhex("c1 07 f2 48 cd c9 c9 07 00"))
)
assert flow().messages[0].content == "Hello"
def test_unknown_ext(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
# noinspection PyUnresolvedReferences
http_flow: HTTPFlow = playbook.layer.flow.handshake_flow
http_flow.response.headers["Sec-WebSocket-Extensions"] = "funky-bits; param=42"
assert (
playbook
<< Log("Ignoring unknown WebSocket extension 'funky-bits'.")
<< websocket.WebsocketStartHook(flow)
>> reply()
)
def test_websocket_connection_repr(tctx):
ws = websocket.WebsocketConnection(wsproto.ConnectionType.SERVER, conn=tctx.client)
assert repr(ws)
class TestFragmentizer:
def test_empty(self):
f = websocket.Fragmentizer([b"foo"])
assert list(f(b"")) == []
def test_keep_sizes(self):
f = websocket.Fragmentizer([b"foo", b"bar"])
assert list(f(b"foobaz")) == [
wsproto.events.Message(b"foo", message_finished=False),
wsproto.events.Message(b"baz", message_finished=True),
]
def test_rechunk(self):
f = websocket.Fragmentizer([b"foo"])
f.FRAGMENT_SIZE = 4
assert list(f(b"foobar")) == [
wsproto.events.Message(b"foob", message_finished=False),
wsproto.events.Message(b"ar", message_finished=True),
]

View File

@ -72,7 +72,7 @@ def _merge_sends(lst: typing.List[commands.Command], ignore_hooks: bool, ignore_
current_send = None current_send = None
for x in lst: for x in lst:
if isinstance(x, commands.SendData): if isinstance(x, commands.SendData):
if current_send is None: if current_send is None or current_send.connection != x.connection:
current_send = x current_send = x
yield x yield x
else: else: