prepare WebSocket stack to move to wsproto

This commit is contained in:
Thomas Kriechbaumer 2017-08-12 14:06:10 +02:00
parent 8e9194c2b4
commit 130021b76d
2 changed files with 146 additions and 109 deletions

View File

@ -3,9 +3,14 @@ import socket
import struct import struct
from OpenSSL import SSL from OpenSSL import SSL
from wsproto import events
from wsproto.connection import ConnectionType, WSConnection
from wsproto.extensions import PerMessageDeflate
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import flow from mitmproxy import flow
from mitmproxy.proxy.protocol import base from mitmproxy.proxy.protocol import base
from mitmproxy.net import http
from mitmproxy.net import tcp from mitmproxy.net import tcp
from mitmproxy.net import websockets from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
@ -44,108 +49,139 @@ class WebSocketLayer(base.Layer):
self.client_frame_buffer = [] self.client_frame_buffer = []
self.server_frame_buffer = [] self.server_frame_buffer = []
def _handle_frame(self, frame, source_conn, other_conn, is_server): self.connections = {} # type: Dict[object, WSConnection]
if frame.header.opcode & 0x8 == 0:
return self._handle_data_frame(frame, source_conn, other_conn, is_server)
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
elif frame.header.opcode == websockets.OPCODE.CLOSE:
return self._handle_close(frame, source_conn, other_conn, is_server)
else:
return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
def _handle_data_frame(self, frame, source_conn, other_conn, is_server): extensions = []
if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
extensions = [PerMessageDeflate.name]
fb = self.server_frame_buffer if is_server else self.client_frame_buffer self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
fb.append(frame) extensions=extensions)
self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
host=handshake_flow.request.host,
resource=handshake_flow.request.path,
extensions=extensions)
if frame.header.fin: data = self.connections[self.server_conn].bytes_to_send()
payload = b''.join(f.payload for f in fb) self.connections[self.client_conn].receive_bytes(data)
original_chunk_sizes = [len(f.payload) for f in fb]
message_type = fb[0].header.opcode
compressed_message = fb[0].header.rsv1
fb.clear()
websocket_message = WebSocketMessage(message_type, not is_server, payload) event = next(self.connections[self.client_conn].events())
length = len(websocket_message.content) assert isinstance(event, events.ConnectionRequested)
self.flow.messages.append(websocket_message)
self.channel.ask("websocket_message", self.flow)
if not self.flow.stream: self.connections[self.client_conn].accept(event)
def get_chunk(payload): self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send())
if len(payload) == length: assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished)
# message has the same length, we can reuse the same sizes
pos = 0
for s in original_chunk_sizes:
yield payload[pos:pos + s]
pos += s
else:
# just re-chunk everything into 4kB frames
# header len = 4 bytes without masking key and 8 bytes with masking key
chunk_size = 4092 if is_server else 4088
chunks = range(0, len(payload), chunk_size)
for i in chunks:
yield payload[i:i + chunk_size]
frms = [ def _handle_event(self, event, source_conn, other_conn, is_server):
websockets.Frame( if isinstance(event, events.DataReceived):
payload=chunk, return self._handle_data_received(event, source_conn, other_conn, is_server)
opcode=frame.header.opcode, elif isinstance(event, events.PingReceived):
mask=(False if is_server else 1), return self._handle_ping_received(event, source_conn, other_conn, is_server)
masking_key=(b'' if is_server else os.urandom(4))) elif isinstance(event, events.PongReceived):
for chunk in get_chunk(websocket_message.content) return self._handle_pong_received(event, source_conn, other_conn, is_server)
] elif isinstance(event, events.ConnectionFailed):
return self._handle_connection_closed(event, source_conn, other_conn, is_server)
if len(frms) > 0: elif isinstance(event, events.ConnectionFailed):
frms[-1].header.fin = True return self._handle_connection_failed(event)
else:
frms.append(websockets.Frame(
fin=True,
opcode=websockets.OPCODE.CONTINUE,
mask=(False if is_server else 1),
masking_key=(b'' if is_server else os.urandom(4))))
frms[0].header.opcode = message_type
frms[0].header.rsv1 = compressed_message
for frm in frms:
other_conn.send(bytes(frm))
else:
other_conn.send(bytes(frame))
elif self.flow.stream:
other_conn.send(bytes(frame))
# fail-safe for unhandled events
return True return True
def _handle_ping_pong(self, frame, source_conn, other_conn, is_server): def _handle_data_received(self, event, source_conn, other_conn, is_server):
# just forward the ping/pong to the other side
other_conn.send(bytes(frame))
return True return True
def _handle_close(self, frame, source_conn, other_conn, is_server): def _handle_ping_received(self, event, source_conn, other_conn, is_server):
# PING is automatically answered with a PONG by wsproto
# TODO: log this PING and its payload
self.connections[other_conn].ping(event.payload)
other_conn.send(self.connections[other_conn].bytes_to_send())
return True
def _handle_pong_received(self, event, source_conn, other_conn, is_server):
# TODO: log this PONG and its payload
self.connections[other_conn].pong(event.payload)
other_conn.send(self.connections[other_conn].bytes_to_send())
return True
def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
self.flow.close_sender = "server" if is_server else "client" self.flow.close_sender = "server" if is_server else "client"
if len(frame.payload) >= 2: self.flow.close_code = event.code
code, = struct.unpack('!H', frame.payload[:2]) self.flow.close_reason = event.reason
self.flow.close_code = code
self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
if len(frame.payload) > 2:
self.flow.close_reason = frame.payload[2:]
other_conn.send(bytes(frame)) print(self.connections[other_conn])
self.connections[other_conn].close(event.code, event.reason)
# initiate close handshake # initiate close handshake
return False return False
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server): def _handle_connection_failed(self, event):
# unknown frame - just forward it raise exceptions.TcpException(repr(event))
other_conn.send(bytes(frame))
sender = "server" if is_server else "client" # def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)]) #
# fb = self.server_frame_buffer if is_server else self.client_frame_buffer
return True # fb.append(frame)
#
# if frame.header.fin:
# payload = b''.join(f.payload for f in fb)
# original_chunk_sizes = [len(f.payload) for f in fb]
# message_type = fb[0].header.opcode
# compressed_message = fb[0].header.rsv1
# fb.clear()
#
# websocket_message = WebSocketMessage(message_type, not is_server, payload)
# length = len(websocket_message.content)
# self.flow.messages.append(websocket_message)
# self.channel.ask("websocket_message", self.flow)
#
# if not self.flow.stream:
# def get_chunk(payload):
# if len(payload) == length:
# # message has the same length, we can reuse the same sizes
# pos = 0
# for s in original_chunk_sizes:
# yield payload[pos:pos + s]
# pos += s
# else:
# # just re-chunk everything into 4kB frames
# # header len = 4 bytes without masking key and 8 bytes with masking key
# chunk_size = 4092 if is_server else 4088
# chunks = range(0, len(payload), chunk_size)
# for i in chunks:
# yield payload[i:i + chunk_size]
#
# frms = [
# websockets.Frame(
# payload=chunk,
# opcode=frame.header.opcode,
# mask=(False if is_server else 1),
# masking_key=(b'' if is_server else os.urandom(4)))
# for chunk in get_chunk(websocket_message.content)
# ]
#
# if len(frms) > 0:
# frms[-1].header.fin = True
# else:
# frms.append(websockets.Frame(
# fin=True,
# opcode=websockets.OPCODE.CONTINUE,
# mask=(False if is_server else 1),
# masking_key=(b'' if is_server else os.urandom(4))))
#
# frms[0].header.opcode = message_type
# frms[0].header.rsv1 = compressed_message
#
# for frm in frms:
# other_conn.send(bytes(frm))
#
# else:
# other_conn.send(bytes(frame))
#
# elif self.flow.stream:
# other_conn.send(bytes(frame))
#
# return True
def __call__(self): def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self) self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
@ -153,25 +189,26 @@ class WebSocketLayer(base.Layer):
self.handshake_flow.metadata['websocket_flow'] = self.flow.id self.handshake_flow.metadata['websocket_flow'] = self.flow.id
self.channel.ask("websocket_start", self.flow) self.channel.ask("websocket_start", self.flow)
client = self.client_conn.connection conns = [c.connection for c in self.connections.keys()]
server = self.server_conn.connection
conns = [client, server]
close_received = False close_received = False
try: try:
while not self.channel.should_exit.is_set(): while not self.channel.should_exit.is_set():
r = tcp.ssl_read_select(conns, 0.1) r = tcp.ssl_read_select(conns, 0.1)
for conn in r: for conn in r:
source_conn = self.client_conn if conn == client else self.server_conn source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
other_conn = self.server_conn if conn == client else self.client_conn other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
is_server = (conn == self.server_conn.connection) is_server = (source_conn == self.server_conn)
frame = websockets.Frame.from_file(source_conn.rfile) frame = websockets.Frame.from_file(source_conn.rfile)
self.connections[source_conn].receive_bytes(bytes(frame))
source_conn.send(self.connections[source_conn].bytes_to_send())
cont = self._handle_frame(frame, source_conn, other_conn, is_server) for event in self.connections[source_conn].events():
if not cont: print('is_server:', is_server, 'event:', event)
if not self._handle_event(event, source_conn, other_conn, is_server):
if close_received: if close_received:
return break
else: else:
close_received = True close_received = True
except (socket.error, exceptions.TcpException, SSL.Error) as e: except (socket.error, exceptions.TcpException, SSL.Error) as e:

View File

@ -164,19 +164,19 @@ class TestSimple(_WebSocketTest):
frame = websockets.Frame.from_file(self.client.rfile) frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar' assert frame.payload == b'server-foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush() self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile) frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'self.client-foobar' assert frame.payload == b'self.client-foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
self.client.wfile.flush() self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile) frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'\xde\xad\xbe\xef' assert frame.payload == b'\xde\xad\xbe\xef'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush() self.client.wfile.flush()
assert len(self.master.state.flows) == 2 assert len(self.master.state.flows) == 2
@ -213,13 +213,13 @@ class TestSimpleTLS(_WebSocketTest):
frame = websockets.Frame.from_file(self.client.rfile) frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar' assert frame.payload == b'server-foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush() self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile) frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'self.client-foobar' assert frame.payload == b'self.client-foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush() self.client.wfile.flush()
@ -234,7 +234,7 @@ class TestPing(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PONG assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar' assert frame.payload == b'foobar'
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done')))
wfile.flush() wfile.flush()
def test_ping(self): def test_ping(self):
@ -244,12 +244,12 @@ class TestPing(_WebSocketTest):
assert frame.header.opcode == websockets.OPCODE.PING assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar' assert frame.payload == b'foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
self.client.wfile.flush() self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile) frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == websockets.OPCODE.TEXT assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'pong-received' assert frame.payload == b'done'
class TestPong(_WebSocketTest): class TestPong(_WebSocketTest):
@ -266,7 +266,7 @@ class TestPong(_WebSocketTest):
def test_pong(self): def test_pong(self):
self.setup_connection() self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
self.client.wfile.flush() self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile) frame = websockets.Frame.from_file(self.client.rfile)
@ -289,7 +289,7 @@ class TestClose(_WebSocketTest):
def test_close(self): def test_close(self):
self.setup_connection() self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush() self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile) websockets.Frame.from_file(self.client.rfile)
@ -299,7 +299,7 @@ class TestClose(_WebSocketTest):
def test_close_payload_1(self): def test_close_payload_1(self):
self.setup_connection() self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
self.client.wfile.flush() self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile) websockets.Frame.from_file(self.client.rfile)
@ -309,7 +309,7 @@ class TestClose(_WebSocketTest):
def test_close_payload_2(self): def test_close_payload_2(self):
self.setup_connection() self.setup_connection()
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
self.client.wfile.flush() self.client.wfile.flush()
websockets.Frame.from_file(self.client.rfile) websockets.Frame.from_file(self.client.rfile)