mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 10:16:27 +00:00
[sans-io] websockets: fix bugs, 100% test coverage 🎉
This commit is contained in:
parent
e79cc6bc24
commit
d32a5d5f33
@ -297,7 +297,12 @@ class HttpStream(layer.Layer):
|
|||||||
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
|
||||||
|
|
||||||
if self.flow.response.status_code == 101:
|
if self.flow.response.status_code == 101:
|
||||||
if self.flow.response.headers.get("upgrade", "").strip().lower() == "websocket":
|
is_websocket = (
|
||||||
|
self.flow.response.headers.get("upgrade", "").lower() == "websocket"
|
||||||
|
and
|
||||||
|
self.flow.request.headers.get("Sec-WebSocket-Version", "") == "13"
|
||||||
|
)
|
||||||
|
if is_websocket:
|
||||||
self.child_layer = websocket.WebsocketLayer(self.context, self.flow)
|
self.child_layer = websocket.WebsocketLayer(self.context, self.flow)
|
||||||
else:
|
else:
|
||||||
self.child_layer = tcp.TCPLayer(self.context)
|
self.child_layer = tcp.TCPLayer(self.context)
|
||||||
|
@ -69,6 +69,9 @@ class WebsocketConnection(wsproto.Connection):
|
|||||||
data = super().send(event)
|
data = super().send(event)
|
||||||
return commands.SendData(self.conn, data)
|
return commands.SendData(self.conn, data)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"WebsocketConnection<{self.state.name}, {self.conn}>"
|
||||||
|
|
||||||
|
|
||||||
class WebsocketLayer(layer.Layer):
|
class WebsocketLayer(layer.Layer):
|
||||||
"""
|
"""
|
||||||
@ -92,7 +95,7 @@ class WebsocketLayer(layer.Layer):
|
|||||||
# Parse extension headers. We only support deflate at the moment and ignore everything else.
|
# Parse extension headers. We only support deflate at the moment and ignore everything else.
|
||||||
ext_header = self.flow.handshake_flow.response.headers.get("Sec-WebSocket-Extensions", "")
|
ext_header = self.flow.handshake_flow.response.headers.get("Sec-WebSocket-Extensions", "")
|
||||||
if ext_header:
|
if ext_header:
|
||||||
for ext in wsproto.utilities.split_comma_header(ext_header):
|
for ext in wsproto.utilities.split_comma_header(ext_header.encode("ascii", "replace")):
|
||||||
ext_name = ext.split(";", 1)[0].strip()
|
ext_name = ext.split(";", 1)[0].strip()
|
||||||
if ext_name == wsproto.extensions.PerMessageDeflate.name:
|
if ext_name == wsproto.extensions.PerMessageDeflate.name:
|
||||||
client_deflate = wsproto.extensions.PerMessageDeflate()
|
client_deflate = wsproto.extensions.PerMessageDeflate()
|
||||||
@ -109,7 +112,7 @@ class WebsocketLayer(layer.Layer):
|
|||||||
|
|
||||||
yield WebsocketStartHook(self.flow)
|
yield WebsocketStartHook(self.flow)
|
||||||
|
|
||||||
if self.flow.stream:
|
if self.flow.stream: # pragma: no cover
|
||||||
raise NotImplementedError("WebSocket streaming is not supported at the moment.")
|
raise NotImplementedError("WebSocket streaming is not supported at the moment.")
|
||||||
|
|
||||||
self._handle_event = self.relay_messages
|
self._handle_event = self.relay_messages
|
||||||
@ -130,10 +133,8 @@ class WebsocketLayer(layer.Layer):
|
|||||||
if isinstance(event, events.DataReceived):
|
if isinstance(event, events.DataReceived):
|
||||||
src_ws.receive_data(event.data)
|
src_ws.receive_data(event.data)
|
||||||
elif isinstance(event, events.ConnectionClosed):
|
elif isinstance(event, events.ConnectionClosed):
|
||||||
if src_ws.state not in {ConnectionState.OPEN, ConnectionState.LOCAL_CLOSING}:
|
|
||||||
return
|
|
||||||
src_ws.receive_data(None)
|
src_ws.receive_data(None)
|
||||||
else:
|
else: # pragma: no cover
|
||||||
raise AssertionError(f"Unexpected event: {event}")
|
raise AssertionError(f"Unexpected event: {event}")
|
||||||
|
|
||||||
for ws_event in src_ws.events():
|
for ws_event in src_ws.events():
|
||||||
@ -143,14 +144,10 @@ class WebsocketLayer(layer.Layer):
|
|||||||
if ws_event.message_finished:
|
if ws_event.message_finished:
|
||||||
if isinstance(ws_event, wsproto.events.TextMessage):
|
if isinstance(ws_event, wsproto.events.TextMessage):
|
||||||
frame_type = Opcode.TEXT
|
frame_type = Opcode.TEXT
|
||||||
content = ""
|
content = "".join(src_ws.frame_buf)
|
||||||
else:
|
else:
|
||||||
frame_type = Opcode.BINARY
|
frame_type = Opcode.BINARY
|
||||||
content = b""
|
content = b"".join(src_ws.frame_buf)
|
||||||
try:
|
|
||||||
content = content.join(src_ws.frame_buf)
|
|
||||||
except TypeError:
|
|
||||||
return self.handle_protocol_error(src_ws, "mixed text and binary fragments")
|
|
||||||
|
|
||||||
fragmentizer = Fragmentizer(src_ws.frame_buf)
|
fragmentizer = Fragmentizer(src_ws.frame_buf)
|
||||||
src_ws.frame_buf.clear()
|
src_ws.frame_buf.clear()
|
||||||
@ -166,8 +163,8 @@ class WebsocketLayer(layer.Layer):
|
|||||||
|
|
||||||
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
|
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
|
||||||
yield commands.Log(
|
yield commands.Log(
|
||||||
f"Received WebSocket {event.__class__.__name__.lower()} from {from_str} "
|
f"Received WebSocket {ws_event.__class__.__name__.lower()} from {from_str} "
|
||||||
f"(payload: {ws_event.payload!r})"
|
f"(payload: {bytes(ws_event.payload)!r})"
|
||||||
)
|
)
|
||||||
yield dst_ws.send(ws_event)
|
yield dst_ws.send(ws_event)
|
||||||
elif isinstance(ws_event, wsproto.events.CloseConnection):
|
elif isinstance(ws_event, wsproto.events.CloseConnection):
|
||||||
@ -175,7 +172,7 @@ class WebsocketLayer(layer.Layer):
|
|||||||
self.flow.close_code = ws_event.code
|
self.flow.close_code = ws_event.code
|
||||||
self.flow.close_reason = ws_event.reason
|
self.flow.close_reason = ws_event.reason
|
||||||
|
|
||||||
for ws in [self.client_ws, self.server_ws]:
|
for ws in [self.server_ws, self.client_ws]:
|
||||||
if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
|
if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
|
||||||
# response == original event, so no need to differentiate here.
|
# response == original event, so no need to differentiate here.
|
||||||
yield ws.send(ws_event)
|
yield ws.send(ws_event)
|
||||||
@ -185,17 +182,9 @@ class WebsocketLayer(layer.Layer):
|
|||||||
else:
|
else:
|
||||||
self.flow.error = flow.Error(f"WebSocket Error: {format_close_event(ws_event)}")
|
self.flow.error = flow.Error(f"WebSocket Error: {format_close_event(ws_event)}")
|
||||||
yield WebsocketErrorHook(self.flow)
|
yield WebsocketErrorHook(self.flow)
|
||||||
yield commands.CloseConnection(self.context.client)
|
|
||||||
else:
|
|
||||||
raise AssertionError(f"Unexpected WebSocket event: {ws_event}")
|
|
||||||
|
|
||||||
def handle_protocol_error(self, ws: WebsocketConnection, message=None):
|
|
||||||
self.flow.error = flow.Error(f"WebSocket Error: {human.format_address(ws.conn.peername)} {message}")
|
|
||||||
yield WebsocketErrorHook(self.flow)
|
|
||||||
if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
|
|
||||||
yield ws.send(wsproto.events.CloseConnection(CloseReason.PROTOCOL_ERROR, message))
|
|
||||||
yield commands.CloseConnection(self.context.client)
|
|
||||||
self._handle_event = self.done
|
self._handle_event = self.done
|
||||||
|
else: # pragma: no cover
|
||||||
|
raise AssertionError(f"Unexpected WebSocket event: {ws_event}")
|
||||||
|
|
||||||
@expect(events.DataReceived, events.ConnectionClosed)
|
@expect(events.DataReceived, events.ConnectionClosed)
|
||||||
def done(self, _) -> layer.CommandGenerator[None]:
|
def done(self, _) -> layer.CommandGenerator[None]:
|
||||||
@ -219,7 +208,7 @@ class Fragmentizer:
|
|||||||
meaning. An intermediary might coalesce and/or split frames, [...]
|
meaning. An intermediary might coalesce and/or split frames, [...]
|
||||||
|
|
||||||
Practice:
|
Practice:
|
||||||
Some WebSocket servers reject large payload sizes. ¯\_(ツ)_/¯
|
Some WebSocket servers reject large payload sizes.
|
||||||
|
|
||||||
As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks.
|
As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks.
|
||||||
"""
|
"""
|
||||||
|
@ -0,0 +1,334 @@
|
|||||||
|
import secrets
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import wsproto
|
||||||
|
import wsproto.events
|
||||||
|
from mitmproxy.http import HTTPFlow
|
||||||
|
from mitmproxy.net.http import Request, Response
|
||||||
|
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||||
|
from mitmproxy.proxy2.commands import SendData, CloseConnection, Log
|
||||||
|
from mitmproxy.proxy2.context import Server, ConnectionState
|
||||||
|
from mitmproxy.proxy2.events import DataReceived, ConnectionClosed
|
||||||
|
from mitmproxy.proxy2.layers import http, websocket
|
||||||
|
from mitmproxy.websocket import WebSocketFlow
|
||||||
|
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Masked:
|
||||||
|
unmasked: bytes
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
other = bytearray(other)
|
||||||
|
assert other[1] & 0b1000_0000 # assert this is actually masked
|
||||||
|
other[1] &= 0b0111_1111 # remove mask bit
|
||||||
|
assert other[1] < 126 # (we don't support extended payload length here)
|
||||||
|
mask = other[2:6]
|
||||||
|
payload = bytes([x ^ mask[i % 4] for i, x in enumerate(other[6:])])
|
||||||
|
return self.unmasked == other[:2] + payload
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
def masked(unmasked: bytes) -> bytes:
|
||||||
|
return _Masked(unmasked) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def masked_bytes(unmasked: bytes) -> bytes:
|
||||||
|
header = bytearray(unmasked[:2])
|
||||||
|
assert header[1] < 126 # assert that this is neither masked nor extended payload
|
||||||
|
header[1] |= 0b1000_0000
|
||||||
|
mask = secrets.token_bytes(4)
|
||||||
|
masked = bytes([x ^ mask[i % 4] for i, x in enumerate(unmasked[2:])])
|
||||||
|
return bytes(header + mask + masked)
|
||||||
|
|
||||||
|
|
||||||
|
def test_masking():
|
||||||
|
m = masked(b"\x02\x03foo")
|
||||||
|
assert m == b"\x02\x83\x1c\x96\xd4\rz\xf9\xbb"
|
||||||
|
assert m == masked_bytes(b"\x02\x03foo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_upgrade(tctx):
|
||||||
|
"""Test a HTTP -> WebSocket upgrade"""
|
||||||
|
tctx.server.address = ("example.com", 80)
|
||||||
|
tctx.server.state = ConnectionState.OPEN
|
||||||
|
http_flow = Placeholder(HTTPFlow)
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
assert (
|
||||||
|
Playbook(http.HttpLayer(tctx, HTTPMode.transparent))
|
||||||
|
>> DataReceived(tctx.client,
|
||||||
|
b"GET / HTTP/1.1\r\n"
|
||||||
|
b"Connection: upgrade\r\n"
|
||||||
|
b"Upgrade: websocket\r\n"
|
||||||
|
b"Sec-WebSocket-Version: 13\r\n"
|
||||||
|
b"\r\n")
|
||||||
|
<< http.HttpRequestHeadersHook(http_flow)
|
||||||
|
>> reply()
|
||||||
|
<< http.HttpRequestHook(http_flow)
|
||||||
|
>> reply()
|
||||||
|
<< SendData(tctx.server, b"GET / HTTP/1.1\r\n"
|
||||||
|
b"Connection: upgrade\r\n"
|
||||||
|
b"Upgrade: websocket\r\n"
|
||||||
|
b"Sec-WebSocket-Version: 13\r\n"
|
||||||
|
b"\r\n")
|
||||||
|
>> DataReceived(tctx.server, b"HTTP/1.1 101 Switching Protocols\r\n"
|
||||||
|
b"Upgrade: websocket\r\n"
|
||||||
|
b"Connection: Upgrade\r\n"
|
||||||
|
b"\r\n")
|
||||||
|
<< http.HttpResponseHeadersHook(http_flow)
|
||||||
|
>> reply()
|
||||||
|
<< http.HttpResponseHook(http_flow)
|
||||||
|
>> reply()
|
||||||
|
<< SendData(tctx.client, b"HTTP/1.1 101 Switching Protocols\r\n"
|
||||||
|
b"Upgrade: websocket\r\n"
|
||||||
|
b"Connection: Upgrade\r\n"
|
||||||
|
b"\r\n")
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> DataReceived(tctx.client, masked_bytes(b"\x81\x0bhello world"))
|
||||||
|
<< websocket.WebsocketMessageHook(flow)
|
||||||
|
>> reply()
|
||||||
|
<< SendData(tctx.server, masked(b"\x81\x0bhello world"))
|
||||||
|
>> DataReceived(tctx.server, b"\x82\nhello back")
|
||||||
|
<< websocket.WebsocketMessageHook(flow)
|
||||||
|
>> reply()
|
||||||
|
<< SendData(tctx.client, b"\x82\nhello back")
|
||||||
|
)
|
||||||
|
assert flow().handshake_flow == http_flow()
|
||||||
|
assert len(flow().messages) == 2
|
||||||
|
assert flow().messages[0].content == "hello world"
|
||||||
|
assert flow().messages[0].from_client
|
||||||
|
assert flow().messages[1].content == b"hello back"
|
||||||
|
assert flow().messages[1].from_client is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def ws_testdata(tctx):
|
||||||
|
tctx.server.address = ("example.com", 80)
|
||||||
|
tctx.server.state = ConnectionState.OPEN
|
||||||
|
flow = HTTPFlow(
|
||||||
|
tctx.client,
|
||||||
|
tctx.server
|
||||||
|
)
|
||||||
|
flow.request = Request.make("GET", "http://example.com/", headers={
|
||||||
|
"Connection": "upgrade",
|
||||||
|
"Upgrade": "websocket",
|
||||||
|
"Sec-WebSocket-Version": "13",
|
||||||
|
})
|
||||||
|
flow.response = Response.make(101, headers={
|
||||||
|
"Connection": "upgrade",
|
||||||
|
"Upgrade": "websocket",
|
||||||
|
})
|
||||||
|
return tctx, Playbook(websocket.WebsocketLayer(tctx, flow))
|
||||||
|
|
||||||
|
|
||||||
|
def test_modify_message(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> DataReceived(tctx.server, b"\x81\x03foo")
|
||||||
|
<< websocket.WebsocketMessageHook(flow)
|
||||||
|
)
|
||||||
|
flow().messages[-1].content = flow().messages[-1].content.replace("foo", "foobar")
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> reply()
|
||||||
|
<< SendData(tctx.client, b"\x81\x06foobar")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop_message(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> DataReceived(tctx.server, b"\x81\x03foo")
|
||||||
|
<< websocket.WebsocketMessageHook(flow)
|
||||||
|
)
|
||||||
|
flow().messages[-1].content = ""
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> reply()
|
||||||
|
<< None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fragmented(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> DataReceived(tctx.server, b"\x01\x03foo")
|
||||||
|
>> DataReceived(tctx.server, b"\x80\x03bar")
|
||||||
|
<< websocket.WebsocketMessageHook(flow)
|
||||||
|
>> reply()
|
||||||
|
<< SendData(tctx.client, b"\x01\x03foo")
|
||||||
|
<< SendData(tctx.client, b"\x80\x03bar")
|
||||||
|
)
|
||||||
|
assert flow().messages[-1].content == "foobar"
|
||||||
|
|
||||||
|
|
||||||
|
def test_protocol_error(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> DataReceived(tctx.server, b"\x01\x03foo")
|
||||||
|
>> DataReceived(tctx.server, b"\x02\x03bar")
|
||||||
|
<< SendData(tctx.server, masked(b"\x88/\x03\xeaexpected CONTINUATION, got <Opcode.BINARY: 2>"))
|
||||||
|
<< CloseConnection(tctx.server)
|
||||||
|
<< SendData(tctx.client, b"\x88/\x03\xeaexpected CONTINUATION, got <Opcode.BINARY: 2>")
|
||||||
|
<< CloseConnection(tctx.client)
|
||||||
|
<< websocket.WebsocketErrorHook(flow)
|
||||||
|
>> reply()
|
||||||
|
|
||||||
|
)
|
||||||
|
assert not flow().messages
|
||||||
|
|
||||||
|
|
||||||
|
def test_ping(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> DataReceived(tctx.client, masked_bytes(b"\x89\x11ping-with-payload"))
|
||||||
|
<< Log("Received WebSocket ping from client (payload: b'ping-with-payload')")
|
||||||
|
<< SendData(tctx.server, masked(b"\x89\x11ping-with-payload"))
|
||||||
|
>> DataReceived(tctx.server, b"\x8a\x11pong-with-payload")
|
||||||
|
<< Log("Received WebSocket pong from server (payload: b'pong-with-payload')")
|
||||||
|
<< SendData(tctx.client, b"\x8a\x11pong-with-payload")
|
||||||
|
)
|
||||||
|
assert not flow().messages
|
||||||
|
|
||||||
|
|
||||||
|
def test_close_normal(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
masked_close = Placeholder(bytes)
|
||||||
|
close = Placeholder(bytes)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> DataReceived(tctx.client, masked_bytes(b"\x88\x00"))
|
||||||
|
<< SendData(tctx.server, masked_close)
|
||||||
|
<< CloseConnection(tctx.server)
|
||||||
|
<< SendData(tctx.client, close)
|
||||||
|
<< CloseConnection(tctx.client)
|
||||||
|
<< websocket.WebsocketEndHook(flow)
|
||||||
|
>> reply()
|
||||||
|
)
|
||||||
|
# wsproto currently handles this inconsistently, see
|
||||||
|
# https://github.com/python-hyper/wsproto/pull/153/files
|
||||||
|
assert masked_close() == masked(b"\x88\x02\x03\xe8") or masked_close() == masked(b"\x88\x00")
|
||||||
|
assert close() == b"\x88\x02\x03\xe8" or close() == b"\x88\x00"
|
||||||
|
|
||||||
|
assert flow().close_code == 1005
|
||||||
|
|
||||||
|
|
||||||
|
def test_close_disconnect(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> ConnectionClosed(tctx.server)
|
||||||
|
<< CloseConnection(tctx.server)
|
||||||
|
<< SendData(tctx.client, b"\x88\x02\x03\xe8")
|
||||||
|
<< CloseConnection(tctx.client)
|
||||||
|
<< websocket.WebsocketErrorHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> ConnectionClosed(tctx.client)
|
||||||
|
)
|
||||||
|
assert "ABNORMAL_CLOSURE" in flow().error.msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_close_error(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
>> DataReceived(tctx.server, b"\x88\x02\x0f\xa0")
|
||||||
|
<< SendData(tctx.server, masked(b"\x88\x02\x0f\xa0"))
|
||||||
|
<< CloseConnection(tctx.server)
|
||||||
|
<< SendData(tctx.client, b"\x88\x02\x0f\xa0")
|
||||||
|
<< CloseConnection(tctx.client)
|
||||||
|
<< websocket.WebsocketErrorHook(flow)
|
||||||
|
>> reply()
|
||||||
|
)
|
||||||
|
assert "UNKNOWN_ERROR=4000" in flow().error.msg
|
||||||
|
|
||||||
|
|
||||||
|
def test_deflate(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
http_flow: HTTPFlow = playbook.layer.flow.handshake_flow
|
||||||
|
http_flow.response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate; server_max_window_bits=10"
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
# https://tools.ietf.org/html/rfc7692#section-7.2.3.1
|
||||||
|
>> DataReceived(tctx.server, bytes.fromhex("c1 07 f2 48 cd c9 c9 07 00"))
|
||||||
|
<< websocket.WebsocketMessageHook(flow)
|
||||||
|
>> reply()
|
||||||
|
<< SendData(tctx.client, bytes.fromhex("c1 07 f2 48 cd c9 c9 07 00"))
|
||||||
|
)
|
||||||
|
assert flow().messages[0].content == "Hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_ext(ws_testdata):
|
||||||
|
tctx, playbook = ws_testdata
|
||||||
|
flow = Placeholder(WebSocketFlow)
|
||||||
|
# noinspection PyUnresolvedReferences
|
||||||
|
http_flow: HTTPFlow = playbook.layer.flow.handshake_flow
|
||||||
|
http_flow.response.headers["Sec-WebSocket-Extensions"] = "funky-bits; param=42"
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
<< Log("Ignoring unknown WebSocket extension 'funky-bits'.")
|
||||||
|
<< websocket.WebsocketStartHook(flow)
|
||||||
|
>> reply()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_connection_repr(tctx):
|
||||||
|
ws = websocket.WebsocketConnection(wsproto.ConnectionType.SERVER, conn=tctx.client)
|
||||||
|
assert repr(ws)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFragmentizer:
|
||||||
|
def test_empty(self):
|
||||||
|
f = websocket.Fragmentizer([b"foo"])
|
||||||
|
assert list(f(b"")) == []
|
||||||
|
|
||||||
|
def test_keep_sizes(self):
|
||||||
|
f = websocket.Fragmentizer([b"foo", b"bar"])
|
||||||
|
assert list(f(b"foobaz")) == [
|
||||||
|
wsproto.events.Message(b"foo", message_finished=False),
|
||||||
|
wsproto.events.Message(b"baz", message_finished=True),
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_rechunk(self):
|
||||||
|
f = websocket.Fragmentizer([b"foo"])
|
||||||
|
f.FRAGMENT_SIZE = 4
|
||||||
|
assert list(f(b"foobar")) == [
|
||||||
|
wsproto.events.Message(b"foob", message_finished=False),
|
||||||
|
wsproto.events.Message(b"ar", message_finished=True),
|
||||||
|
]
|
@ -72,7 +72,7 @@ def _merge_sends(lst: typing.List[commands.Command], ignore_hooks: bool, ignore_
|
|||||||
current_send = None
|
current_send = None
|
||||||
for x in lst:
|
for x in lst:
|
||||||
if isinstance(x, commands.SendData):
|
if isinstance(x, commands.SendData):
|
||||||
if current_send is None:
|
if current_send is None or current_send.connection != x.connection:
|
||||||
current_send = x
|
current_send = x
|
||||||
yield x
|
yield x
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user