prepare WebSocket stack to move to wsproto

This commit is contained in:
Thomas Kriechbaumer 2017-08-12 14:06:10 +02:00
parent 8e9194c2b4
commit 130021b76d
2 changed files with 146 additions and 109 deletions

View File

@ -3,9 +3,14 @@ import socket
import struct
from OpenSSL import SSL
from wsproto import events
from wsproto.connection import ConnectionType, WSConnection
from wsproto.extensions import PerMessageDeflate
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
from mitmproxy.net import http
from mitmproxy.net import tcp
from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
@ -44,108 +49,139 @@ class WebSocketLayer(base.Layer):
self.client_frame_buffer = []
self.server_frame_buffer = []
def _handle_frame(self, frame, source_conn, other_conn, is_server):
if frame.header.opcode & 0x8 == 0:
return self._handle_data_frame(frame, source_conn, other_conn, is_server)
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
elif frame.header.opcode == websockets.OPCODE.CLOSE:
return self._handle_close(frame, source_conn, other_conn, is_server)
else:
return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
self.connections = {} # type: Dict[object, WSConnection]
def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
extensions = []
if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
extensions = [PerMessageDeflate.name]
fb = self.server_frame_buffer if is_server else self.client_frame_buffer
fb.append(frame)
self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
extensions=extensions)
self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
host=handshake_flow.request.host,
resource=handshake_flow.request.path,
extensions=extensions)
if frame.header.fin:
payload = b''.join(f.payload for f in fb)
original_chunk_sizes = [len(f.payload) for f in fb]
message_type = fb[0].header.opcode
compressed_message = fb[0].header.rsv1
fb.clear()
data = self.connections[self.server_conn].bytes_to_send()
self.connections[self.client_conn].receive_bytes(data)
websocket_message = WebSocketMessage(message_type, not is_server, payload)
length = len(websocket_message.content)
self.flow.messages.append(websocket_message)
self.channel.ask("websocket_message", self.flow)
event = next(self.connections[self.client_conn].events())
assert isinstance(event, events.ConnectionRequested)
if not self.flow.stream:
def get_chunk(payload):
if len(payload) == length:
# message has the same length, we can reuse the same sizes
pos = 0
for s in original_chunk_sizes:
yield payload[pos:pos + s]
pos += s
else:
# just re-chunk everything into 4kB frames
# header len = 4 bytes without masking key and 8 bytes with masking key
chunk_size = 4092 if is_server else 4088
chunks = range(0, len(payload), chunk_size)
for i in chunks:
yield payload[i:i + chunk_size]
self.connections[self.client_conn].accept(event)
self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send())
assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished)
frms = [
websockets.Frame(
payload=chunk,
opcode=frame.header.opcode,
mask=(False if is_server else 1),
masking_key=(b'' if is_server else os.urandom(4)))
for chunk in get_chunk(websocket_message.content)
]
if len(frms) > 0:
frms[-1].header.fin = True
else:
frms.append(websockets.Frame(
fin=True,
opcode=websockets.OPCODE.CONTINUE,
mask=(False if is_server else 1),
masking_key=(b'' if is_server else os.urandom(4))))
frms[0].header.opcode = message_type
frms[0].header.rsv1 = compressed_message
for frm in frms:
other_conn.send(bytes(frm))
else:
other_conn.send(bytes(frame))
elif self.flow.stream:
other_conn.send(bytes(frame))
def _handle_event(self, event, source_conn, other_conn, is_server):
if isinstance(event, events.DataReceived):
return self._handle_data_received(event, source_conn, other_conn, is_server)
elif isinstance(event, events.PingReceived):
return self._handle_ping_received(event, source_conn, other_conn, is_server)
elif isinstance(event, events.PongReceived):
return self._handle_pong_received(event, source_conn, other_conn, is_server)
elif isinstance(event, events.ConnectionFailed):
return self._handle_connection_closed(event, source_conn, other_conn, is_server)
elif isinstance(event, events.ConnectionFailed):
return self._handle_connection_failed(event)
# fail-safe for unhandled events
return True
def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
# just forward the ping/pong to the other side
other_conn.send(bytes(frame))
def _handle_data_received(self, event, source_conn, other_conn, is_server):
return True
def _handle_close(self, frame, source_conn, other_conn, is_server):
def _handle_ping_received(self, event, source_conn, other_conn, is_server):
# PING is automatically answered with a PONG by wsproto
# TODO: log this PING and its payload
self.connections[other_conn].ping(event.payload)
other_conn.send(self.connections[other_conn].bytes_to_send())
return True
def _handle_pong_received(self, event, source_conn, other_conn, is_server):
# TODO: log this PONG and its payload
self.connections[other_conn].pong(event.payload)
other_conn.send(self.connections[other_conn].bytes_to_send())
return True
def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
self.flow.close_sender = "server" if is_server else "client"
if len(frame.payload) >= 2:
code, = struct.unpack('!H', frame.payload[:2])
self.flow.close_code = code
self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
if len(frame.payload) > 2:
self.flow.close_reason = frame.payload[2:]
self.flow.close_code = event.code
self.flow.close_reason = event.reason
other_conn.send(bytes(frame))
print(self.connections[other_conn])
self.connections[other_conn].close(event.code, event.reason)
# initiate close handshake
return False
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
# unknown frame - just forward it
other_conn.send(bytes(frame))
def _handle_connection_failed(self, event):
raise exceptions.TcpException(repr(event))
sender = "server" if is_server else "client"
self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
return True
# def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
#
# fb = self.server_frame_buffer if is_server else self.client_frame_buffer
# fb.append(frame)
#
# if frame.header.fin:
# payload = b''.join(f.payload for f in fb)
# original_chunk_sizes = [len(f.payload) for f in fb]
# message_type = fb[0].header.opcode
# compressed_message = fb[0].header.rsv1
# fb.clear()
#
# websocket_message = WebSocketMessage(message_type, not is_server, payload)
# length = len(websocket_message.content)
# self.flow.messages.append(websocket_message)
# self.channel.ask("websocket_message", self.flow)
#
# if not self.flow.stream:
# def get_chunk(payload):
# if len(payload) == length:
# # message has the same length, we can reuse the same sizes
# pos = 0
# for s in original_chunk_sizes:
# yield payload[pos:pos + s]
# pos += s
# else:
# # just re-chunk everything into 4kB frames
# # header len = 4 bytes without masking key and 8 bytes with masking key
# chunk_size = 4092 if is_server else 4088
# chunks = range(0, len(payload), chunk_size)
# for i in chunks:
# yield payload[i:i + chunk_size]
#
# frms = [
# websockets.Frame(
# payload=chunk,
# opcode=frame.header.opcode,
# mask=(False if is_server else 1),
# masking_key=(b'' if is_server else os.urandom(4)))
# for chunk in get_chunk(websocket_message.content)
# ]
#
# if len(frms) > 0:
# frms[-1].header.fin = True
# else:
# frms.append(websockets.Frame(
# fin=True,
# opcode=websockets.OPCODE.CONTINUE,
# mask=(False if is_server else 1),
# masking_key=(b'' if is_server else os.urandom(4))))
#
# frms[0].header.opcode = message_type
# frms[0].header.rsv1 = compressed_message
#
# for frm in frms:
# other_conn.send(bytes(frm))
#
# else:
# other_conn.send(bytes(frame))
#
# elif self.flow.stream:
# other_conn.send(bytes(frame))
#
# return True
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
@ -153,27 +189,28 @@ class WebSocketLayer(base.Layer):
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
self.channel.ask("websocket_start", self.flow)
client = self.client_conn.connection
server = self.server_conn.connection
conns = [client, server]
conns = [c.connection for c in self.connections.keys()]
close_received = False
try:
while not self.channel.should_exit.is_set():
r = tcp.ssl_read_select(conns, 0.1)
for conn in r:
source_conn = self.client_conn if conn == client else self.server_conn
other_conn = self.server_conn if conn == client else self.client_conn
is_server = (conn == self.server_conn.connection)
source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
is_server = (source_conn == self.server_conn)
frame = websockets.Frame.from_file(source_conn.rfile)
self.connections[source_conn].receive_bytes(bytes(frame))
source_conn.send(self.connections[source_conn].bytes_to_send())
cont = self._handle_frame(frame, source_conn, other_conn, is_server)
if not cont:
if close_received:
return
else:
close_received = True
for event in self.connections[source_conn].events():
print('is_server:', is_server, 'event:', event)
if not self._handle_event(event, source_conn, other_conn, is_server):
if close_received:
break
else:
close_received = True
except (socket.error, exceptions.TcpException, SSL.Error) as e:
s = 'server' if is_server else 'client'
self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e)))

View File

@ -164,19 +164,19 @@ class TestSimple(_WebSocketTest):
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'self.client-foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'\xde\xad\xbe\xef'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
assert len(self.master.state.flows) == 2
@ -213,13 +213,13 @@ class TestSimpleTLS(_WebSocketTest):
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'self.client-foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
@ -234,7 +234,7 @@ class TestPing(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done')))
wfile.flush()
def test_ping(self):
@ -244,12 +244,12 @@ class TestPing(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == websockets.OPCODE.TEXT
assert frame.payload == b'pong-received'
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'done'
class TestPong(_WebSocketTest):
@ -266,7 +266,7 @@ class TestPong(_WebSocketTest):
def test_pong(self):
self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
@ -289,7 +289,7 @@ class TestClose(_WebSocketTest):
def test_close(self):
self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile)
@ -299,7 +299,7 @@ class TestClose(_WebSocketTest):
def test_close_payload_1(self):
self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile)
@ -309,7 +309,7 @@ class TestClose(_WebSocketTest):
def test_close_payload_2(self):
self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile)