refactor websockets

This commit is contained in:
Thomas Kriechbaumer 2018-06-10 19:38:09 +02:00 committed by Maximilian Hils
parent 34f3573be5
commit bc20b77c48
3 changed files with 259 additions and 118 deletions

View File

@ -1,10 +1,13 @@
import wsproto
from wsproto import events as wsevents
from wsproto.connection import ConnectionType, WSConnection
from wsproto.extensions import PerMessageDeflate
from mitmproxy import websocket, http, flow
from mitmproxy.proxy2 import events, commands
from mitmproxy.proxy2.context import Context
from mitmproxy.proxy2.layer import Layer
from mitmproxy.proxy2.utils import expect
from wsproto import connection as wsconn
from wsproto import events as wsevents
class WebsocketLayer(Layer):
@ -17,18 +20,40 @@ class WebsocketLayer(Layer):
def __init__(self, context: Context, handshake_flow: http.HTTPFlow):
super().__init__(context)
self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow)
assert context.server.connected
self.flow.metadata['websocket_handshake'] = handshake_flow.id
self.handshake_flow = handshake_flow
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
self.client_frame_buffer = []
self.server_frame_buffer = []
extension = self.flow.server_extensions
self.client_conn = wsconn.WSConnection(wsconn.SERVER, wsconn.ConnectionState.OPEN,
extensions=[extension] if extension else None)
self.server_conn = wsconn.WSConnection(wsconn.CLIENT, wsconn.ConnectionState.OPEN,
extensions=[extension] if extension else None)
assert context.server.connected
@expect(events.Start)
def start(self, _) -> commands.TCommandGenerator:
extensions = []
if 'Sec-WebSocket-Extensions' in self.handshake_flow.response.headers:
if PerMessageDeflate.name in self.handshake_flow.response.headers['Sec-WebSocket-Extensions']:
extensions = [PerMessageDeflate()]
self.client_conn = WSConnection(ConnectionType.SERVER,
extensions=extensions)
self.server_conn = WSConnection(ConnectionType.CLIENT,
host=self.handshake_flow.request.host,
resource=self.handshake_flow.request.path,
extensions=extensions)
if extensions:
self.client_conn.extensions[0].finalize(self.client_conn, self.handshake_flow.response.headers['Sec-WebSocket-Extensions'])
self.server_conn.extensions[0].finalize(self.server_conn, self.handshake_flow.response.headers['Sec-WebSocket-Extensions'])
data = self.server_conn.bytes_to_send()
self.client_conn.receive_bytes(data)
event = next(self.client_conn.events())
assert isinstance(event, wsevents.ConnectionRequested)
self.client_conn.accept(event)
self.server_conn.receive_bytes(self.client_conn.bytes_to_send())
assert isinstance(next(self.server_conn.events()), wsevents.ConnectionEstablished)
yield commands.Hook("websocket_start", self.flow)
self._handle_event = self.relay_messages
@ -51,55 +76,33 @@ class WebsocketLayer(Layer):
source.receive_bytes(event.data)
for ws_event in source.events():
closing = False
received_ws_events = list(source.events())
for ws_event in received_ws_events:
if isinstance(ws_event, wsevents.DataReceived):
fb.append(ws_event.data)
if ws_event.message_finished:
if isinstance(ws_event, wsevents.BytesReceived):
payload = b"".join(fb)
else:
payload = "".join(fb)
fb.clear()
websocket_message = websocket.WebSocketMessage(0x1 if isinstance(ws_event, wsevents.TextReceived) else 0x2,
from_client, payload)
self.flow.messages.append(websocket_message)
yield commands.Hook("websocket_message", self.flow)
other.send_data(ws_event.data, ws_event.message_finished)
yield commands.SendData(send_to, other.bytes_to_send())
yield from self._handle_data_received(ws_event, source, other, send_to, from_client, fb)
elif isinstance(ws_event, wsevents.PingReceived):
yield commands.Log(
"info",
"Websocket PING received {}".format(ws_event.payload.decode())
)
other.ping()
yield commands.SendData(send_to, other.bytes_to_send())
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
yield from self._handle_ping_received(ws_event, source, other, send_to, from_client)
elif isinstance(ws_event, wsevents.PongReceived):
yield from self._handle_pong_received(ws_event, source, other, send_to, from_client)
elif isinstance(ws_event, wsevents.ConnectionClosed):
yield from self._handle_connection_closed(ws_event, source, other, send_to, from_client)
closing = True
else:
yield commands.Log(
"info",
"Websocket PONG received {}".format(ws_event.payload.decode())
"WebSocket unhandled event: from {}: {}".format("client" if from_client else "server", ws_event)
)
elif isinstance(ws_event, wsevents.ConnectionClosed):
other.close(ws_event.code, ws_event.reason)
yield commands.SendData(send_to, other.bytes_to_send())
# FIXME: Wait for other end to actually send the closing frame
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
if ws_event.code != 1000:
self.flow.error = flow.Error(
"WebSocket connection closed unexpectedly by {}: {}".format(
"client" if from_client else "server",
ws_event.reason
)
)
yield commands.Hook("websocket_error", self.flow)
if closing:
yield commands.Hook("websocket_end", self.flow)
if not from_client:
yield commands.CloseConnection(self.context.client)
self._handle_event = self.done
# TODO: elif isinstance(event, events.InjectMessage):
# TODO: come up with a solid API to inject messages
elif isinstance(event, events.ConnectionClosed):
yield commands.Log("error", "Connection closed abnormally")
self.flow.error = flow.Error(
@ -116,3 +119,87 @@ class WebsocketLayer(Layer):
@expect(events.DataReceived, events.ConnectionClosed)
def done(self, _):
yield from ()
def _handle_data_received(self, ws_event, source, other, send_to, from_client, fb):
fb.append(ws_event.data)
if ws_event.message_finished:
original_chunk_sizes = [len(f) for f in fb]
if isinstance(ws_event, wsevents.TextReceived):
message_type = wsproto.frame_protocol.Opcode.TEXT
payload = ''.join(fb)
else:
message_type = wsproto.frame_protocol.Opcode.BINARY
payload = b''.join(fb)
fb.clear()
websocket_message = websocket.WebSocketMessage(message_type, from_client, payload)
length = len(websocket_message.content)
self.flow.messages.append(websocket_message)
yield commands.Hook("websocket_message", self.flow)
if not self.flow.stream and not websocket_message.killed:
def get_chunk(payload):
if len(payload) == length:
# message has the same length, we can reuse the same sizes
pos = 0
for s in original_chunk_sizes:
yield (payload[pos:pos + s], True if pos + s == length else False)
pos += s
else:
# just re-chunk everything into 4kB frames
# header len = 4 bytes without masking key and 8 bytes with masking key
chunk_size = 4088 if from_client else 4092
chunks = range(0, len(payload), chunk_size)
for i in chunks:
yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
for chunk, final in get_chunk(websocket_message.content):
other.send_data(chunk, final)
yield commands.SendData(send_to, other.bytes_to_send())
if self.flow.stream:
other.send_data(ws_event.data, ws_event.message_finished)
yield commands.SendData(send_to, other.bytes_to_send())
def _handle_ping_received(self, ws_event, source, other, send_to, from_client):
yield commands.Log(
"info",
"WebSocket PING received from {}: {}".format("client" if from_client else "server",
ws_event.payload.decode() or "<no payload>")
)
# We do not forward the PING payload, as this might leak information!
other.ping()
yield commands.SendData(send_to, other.bytes_to_send())
# PING is automatically answered with a PONG by wsproto
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
def _handle_pong_received(self, ws_event, source, other, send_to, from_client):
yield commands.Log(
"info",
"WebSocket PONG received from {}: {}".format("client" if from_client else "server",
ws_event.payload.decode() or "<no payload>")
)
def _handle_connection_closed(self, ws_event, source, other, send_to, from_client):
self.flow.close_sender = "client" if from_client else "server"
self.flow.close_code = ws_event.code
self.flow.close_reason = ws_event.reason
other.close(ws_event.code, ws_event.reason)
yield commands.SendData(send_to, other.bytes_to_send())
# FIXME: Wait for other end to actually send the closing frame
# FIXME: https://github.com/python-hyper/wsproto/pull/50
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
if ws_event.code != 1000:
self.flow.error = flow.Error(
"WebSocket connection closed unexpectedly by {}: {}".format(
"client" if from_client else "server",
ws_event.reason
)
)
yield commands.Hook("websocket_error", self.flow)

View File

@ -1,7 +1,9 @@
import struct
from unittest import mock
import pytest
from mitmproxy.net.websockets import Frame, OPCODE
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.layers import websocket
from mitmproxy.test import tflow
@ -15,7 +17,8 @@ def ws_playbook(tctx):
websocket.WebsocketLayer(
tctx,
tflow.twebsocketflow().handshake_flow
)
),
ignore_log=False,
)
with mock.patch("os.urandom") as m:
m.return_value = b"\x10\x11\x12\x13"
@ -25,24 +28,32 @@ def ws_playbook(tctx):
def test_simple(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
bytes(Frame(fin=1, mask=1, opcode=OPCODE.TEXT, payload=b'client-foobar')),
bytes(Frame(fin=1, opcode=OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')),
bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
bytes(Frame(fin=1, opcode=OPCODE.TEXT, payload=b'fail')),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, b"\x82\x85\x10\x11\x12\x13Xt~\x7f\x7f") # Frame with payload b"Hello"
>> events.DataReceived(tctx.client, frames[0])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.server, b"\x82\x85\x10\x11\x12\x13Xt~\x7f\x7f")
>> events.DataReceived(tctx.server, b'\x81\x05Hello') # Frame with payload "Hello"
<< commands.SendData(tctx.server, frames[0])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.client, b'\x81\x05Hello')
>> events.DataReceived(tctx.client, b'\x88\x82\x10\x11\x12\x13\x13\xf9') # Closing frame
<< commands.SendData(tctx.server, b'\x88\x82\x10\x11\x12\x13\x13\xf9')
<< commands.SendData(tctx.client, b'\x88\x02\x03\xe8')
<< commands.SendData(tctx.client, frames[1])
>> events.DataReceived(tctx.client, frames[2])
<< commands.SendData(tctx.server, frames[2])
<< commands.SendData(tctx.client, frames[3])
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, b'\x81\x05Hello')
>> events.DataReceived(tctx.server, frames[4])
<< None
)
@ -52,78 +63,24 @@ def test_simple(tctx, ws_playbook):
def test_server_close(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, b'\x88\x02\x03\xe8')
<< commands.SendData(tctx.client, b'\x88\x02\x03\xe8')
<< commands.SendData(tctx.server, b'\x88\x82\x10\x11\x12\x13\x13\xf9')
>> events.DataReceived(tctx.server, frames[0])
<< commands.SendData(tctx.client, frames[0])
<< commands.SendData(tctx.server, frames[1])
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
<< commands.CloseConnection(tctx.client)
)
def test_ping_pong(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, b'\x89\x80\x10\x11\x12\x13') # Ping
<< commands.Log("info", "Websocket PING received ")
<< commands.SendData(tctx.server, b'\x89\x80\x10\x11\x12\x13')
<< commands.SendData(tctx.client, b'\x8a\x00')
>> events.DataReceived(tctx.server, b'\x8a\x00') # Pong
<< commands.Log("info", "Websocket PONG received ")
)
def test_connection_failed(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, b"Not a valid frame")
<< commands.SendData(tctx.server, b'\x88\x94\x10\x11\x12\x13\x13\xfb[}fp~zt1}cs~vv0!jv')
<< commands.SendData(tctx.client, b'\x88\x14\x03\xeaInvalid opcode 0xe')
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
)
def test_extension(tctx):
f = tutils.Placeholder()
tctx.server.connected = True
handshake_flow = tflow.twebsocketflow().handshake_flow
handshake_flow.request.headers["sec-websocket-extensions"] = "permessage-deflate;"
handshake_flow.response.headers["sec-websocket-extensions"] = "permessage-deflate;"
playbook = tutils.playbook(websocket.WebsocketLayer(tctx, handshake_flow))
with mock.patch("os.urandom") as m:
m.return_value = b"\x10\x11\x12\x13"
assert (
playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, b'\xc1\x87\x10\x11\x12\x13\xe2Y\xdf\xda\xd9\x16\x12') # Compressed Frame
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.server, b'\xc1\x87\x10\x11\x12\x13\xe2Y\xdf\xda\xd9\x16\x12')
>> events.DataReceived(tctx.server, b'\xc1\x07\xf2H\xcd\xc9\xc9\x07\x00') # Compressed Frame
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.client, b'\xc1\x07\xf2H\xcd\xc9\xc9\x07\x00')
)
assert len(f().messages) == 2
assert f().messages[0].content == "Hello"
assert f().messages[1].content == "Hello"
def test_connection_closed(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
@ -140,3 +97,100 @@ def test_connection_closed(tctx, ws_playbook):
)
assert f().error
def test_connection_failed(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
b'Not a valid frame',
bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1002) + b'Invalid opcode 0xe')),
bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1002) + b'Invalid opcode 0xe')),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.SendData(tctx.server, frames[1])
<< commands.SendData(tctx.client, frames[2])
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
)
def test_ping_pong(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
bytes(Frame(fin=1, mask=1, opcode=OPCODE.PING)),
bytes(Frame(fin=1, opcode=OPCODE.PONG)),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Log("info", "WebSocket PING received from client: <no payload>")
<< commands.SendData(tctx.server, frames[0])
<< commands.SendData(tctx.client, frames[1])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Log("info", "WebSocket PONG received from server: <no payload>")
)
def test_ping_pong_hidden_payload(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
bytes(Frame(fin=1, opcode=OPCODE.PING, payload=b'foobar')),
bytes(Frame(fin=1, opcode=OPCODE.PING, payload=b'')),
bytes(Frame(fin=1, mask=1, opcode=OPCODE.PONG, payload=b'foobar')),
bytes(Frame(fin=1, mask=1, opcode=OPCODE.PONG, payload=b'')),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[0])
<< commands.Log("info", "WebSocket PING received from server: foobar")
<< commands.SendData(tctx.client, frames[1])
<< commands.SendData(tctx.server, frames[2])
>> events.DataReceived(tctx.client, frames[3])
<< commands.Log("info", "WebSocket PONG received from client: <no payload>")
)
def test_extension(tctx, ws_playbook):
f = tutils.Placeholder()
ws_playbook.layer.handshake_flow.request.headers["sec-websocket-extensions"] = "permessage-deflate;"
ws_playbook.layer.handshake_flow.response.headers["sec-websocket-extensions"] = "permessage-deflate;"
frames = [
bytes(Frame(fin=1, mask=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x48\xcd\xc9\xc9\x07\x00')),
bytes(Frame(fin=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x48\xcd\xc9\xc9\x07\x00')),
bytes(Frame(fin=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x00\x11\x00\x00')),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.server, frames[0])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.client, frames[2])
)
assert len(f().messages) == 2
assert f().messages[0].content == "Hello"
assert f().messages[1].content == "Hello"

View File

@ -150,7 +150,7 @@ class playbook(typing.Generic[T]):
self._errored = True
def _str(x):
arrow = ">" if isinstance(x, events.Event) else "<"
arrow = ">>" if isinstance(x, events.Event) else "<<"
x = str(x) \
.replace('Placeholder:None', '<unset placeholder>') \
.replace('Placeholder:', '')