diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 19546eb2e..d1abd1346 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -3,9 +3,14 @@ import socket import struct from OpenSSL import SSL +from wsproto import events +from wsproto.connection import ConnectionType, WSConnection +from wsproto.extensions import PerMessageDeflate + from mitmproxy import exceptions from mitmproxy import flow from mitmproxy.proxy.protocol import base +from mitmproxy.net import http from mitmproxy.net import tcp from mitmproxy.net import websockets from mitmproxy.websocket import WebSocketFlow, WebSocketMessage @@ -44,108 +49,139 @@ class WebSocketLayer(base.Layer): self.client_frame_buffer = [] self.server_frame_buffer = [] - def _handle_frame(self, frame, source_conn, other_conn, is_server): - if frame.header.opcode & 0x8 == 0: - return self._handle_data_frame(frame, source_conn, other_conn, is_server) - elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG): - return self._handle_ping_pong(frame, source_conn, other_conn, is_server) - elif frame.header.opcode == websockets.OPCODE.CLOSE: - return self._handle_close(frame, source_conn, other_conn, is_server) - else: - return self._handle_unknown_frame(frame, source_conn, other_conn, is_server) + self.connections = {} # type: Dict[object, WSConnection] - def _handle_data_frame(self, frame, source_conn, other_conn, is_server): + extensions = [] + if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers: + if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']: + extensions = [PerMessageDeflate.name] - fb = self.server_frame_buffer if is_server else self.client_frame_buffer - fb.append(frame) + self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER, + extensions=extensions) + self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT, + host=handshake_flow.request.host, + resource=handshake_flow.request.path, + extensions=extensions) - if frame.header.fin: - payload = b''.join(f.payload for f in fb) - original_chunk_sizes = [len(f.payload) for f in fb] - message_type = fb[0].header.opcode - compressed_message = fb[0].header.rsv1 - fb.clear() + data = self.connections[self.server_conn].bytes_to_send() + self.connections[self.client_conn].receive_bytes(data) - websocket_message = WebSocketMessage(message_type, not is_server, payload) - length = len(websocket_message.content) - self.flow.messages.append(websocket_message) - self.channel.ask("websocket_message", self.flow) + event = next(self.connections[self.client_conn].events()) + assert isinstance(event, events.ConnectionRequested) - if not self.flow.stream: - def get_chunk(payload): - if len(payload) == length: - # message has the same length, we can reuse the same sizes - pos = 0 - for s in original_chunk_sizes: - yield payload[pos:pos + s] - pos += s - else: - # just re-chunk everything into 4kB frames - # header len = 4 bytes without masking key and 8 bytes with masking key - chunk_size = 4092 if is_server else 4088 - chunks = range(0, len(payload), chunk_size) - for i in chunks: - yield payload[i:i + chunk_size] + self.connections[self.client_conn].accept(event) + self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send()) + assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished) - frms = [ - websockets.Frame( - payload=chunk, - opcode=frame.header.opcode, - mask=(False if is_server else 1), - masking_key=(b'' if is_server else os.urandom(4))) - for chunk in get_chunk(websocket_message.content) - ] - - if len(frms) > 0: - frms[-1].header.fin = True - else: - frms.append(websockets.Frame( - fin=True, - opcode=websockets.OPCODE.CONTINUE, - mask=(False if is_server else 1), - masking_key=(b'' if is_server else os.urandom(4)))) - - frms[0].header.opcode = message_type - frms[0].header.rsv1 = compressed_message - - for frm in frms: - other_conn.send(bytes(frm)) - - else: - other_conn.send(bytes(frame)) - - elif self.flow.stream: - other_conn.send(bytes(frame)) + def _handle_event(self, event, source_conn, other_conn, is_server): + if isinstance(event, events.DataReceived): + return self._handle_data_received(event, source_conn, other_conn, is_server) + elif isinstance(event, events.PingReceived): + return self._handle_ping_received(event, source_conn, other_conn, is_server) + elif isinstance(event, events.PongReceived): + return self._handle_pong_received(event, source_conn, other_conn, is_server) + elif isinstance(event, events.ConnectionFailed): + return self._handle_connection_closed(event, source_conn, other_conn, is_server) + elif isinstance(event, events.ConnectionFailed): + return self._handle_connection_failed(event) + # fail-safe for unhandled events return True - def _handle_ping_pong(self, frame, source_conn, other_conn, is_server): - # just forward the ping/pong to the other side - other_conn.send(bytes(frame)) + def _handle_data_received(self, event, source_conn, other_conn, is_server): return True - def _handle_close(self, frame, source_conn, other_conn, is_server): + def _handle_ping_received(self, event, source_conn, other_conn, is_server): + # PING is automatically answered with a PONG by wsproto + # TODO: log this PING and its payload + self.connections[other_conn].ping(event.payload) + other_conn.send(self.connections[other_conn].bytes_to_send()) + return True + + def _handle_pong_received(self, event, source_conn, other_conn, is_server): + # TODO: log this PONG and its payload + self.connections[other_conn].pong(event.payload) + other_conn.send(self.connections[other_conn].bytes_to_send()) + return True + + def _handle_connection_closed(self, event, source_conn, other_conn, is_server): self.flow.close_sender = "server" if is_server else "client" - if len(frame.payload) >= 2: - code, = struct.unpack('!H', frame.payload[:2]) - self.flow.close_code = code - self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code') - if len(frame.payload) > 2: - self.flow.close_reason = frame.payload[2:] + self.flow.close_code = event.code + self.flow.close_reason = event.reason - other_conn.send(bytes(frame)) + print(self.connections[other_conn]) + self.connections[other_conn].close(event.code, event.reason) # initiate close handshake return False - def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server): - # unknown frame - just forward it - other_conn.send(bytes(frame)) + def _handle_connection_failed(self, event): + raise exceptions.TcpException(repr(event)) - sender = "server" if is_server else "client" - self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)]) - - return True + # def _handle_data_frame(self, frame, source_conn, other_conn, is_server): + # + # fb = self.server_frame_buffer if is_server else self.client_frame_buffer + # fb.append(frame) + # + # if frame.header.fin: + # payload = b''.join(f.payload for f in fb) + # original_chunk_sizes = [len(f.payload) for f in fb] + # message_type = fb[0].header.opcode + # compressed_message = fb[0].header.rsv1 + # fb.clear() + # + # websocket_message = WebSocketMessage(message_type, not is_server, payload) + # length = len(websocket_message.content) + # self.flow.messages.append(websocket_message) + # self.channel.ask("websocket_message", self.flow) + # + # if not self.flow.stream: + # def get_chunk(payload): + # if len(payload) == length: + # # message has the same length, we can reuse the same sizes + # pos = 0 + # for s in original_chunk_sizes: + # yield payload[pos:pos + s] + # pos += s + # else: + # # just re-chunk everything into 4kB frames + # # header len = 4 bytes without masking key and 8 bytes with masking key + # chunk_size = 4092 if is_server else 4088 + # chunks = range(0, len(payload), chunk_size) + # for i in chunks: + # yield payload[i:i + chunk_size] + # + # frms = [ + # websockets.Frame( + # payload=chunk, + # opcode=frame.header.opcode, + # mask=(False if is_server else 1), + # masking_key=(b'' if is_server else os.urandom(4))) + # for chunk in get_chunk(websocket_message.content) + # ] + # + # if len(frms) > 0: + # frms[-1].header.fin = True + # else: + # frms.append(websockets.Frame( + # fin=True, + # opcode=websockets.OPCODE.CONTINUE, + # mask=(False if is_server else 1), + # masking_key=(b'' if is_server else os.urandom(4)))) + # + # frms[0].header.opcode = message_type + # frms[0].header.rsv1 = compressed_message + # + # for frm in frms: + # other_conn.send(bytes(frm)) + # + # else: + # other_conn.send(bytes(frame)) + # + # elif self.flow.stream: + # other_conn.send(bytes(frame)) + # + # return True def __call__(self): self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self) @@ -153,27 +189,28 @@ class WebSocketLayer(base.Layer): self.handshake_flow.metadata['websocket_flow'] = self.flow.id self.channel.ask("websocket_start", self.flow) - client = self.client_conn.connection - server = self.server_conn.connection - conns = [client, server] + conns = [c.connection for c in self.connections.keys()] close_received = False try: while not self.channel.should_exit.is_set(): r = tcp.ssl_read_select(conns, 0.1) for conn in r: - source_conn = self.client_conn if conn == client else self.server_conn - other_conn = self.server_conn if conn == client else self.client_conn - is_server = (conn == self.server_conn.connection) + source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn + other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn + is_server = (source_conn == self.server_conn) frame = websockets.Frame.from_file(source_conn.rfile) + self.connections[source_conn].receive_bytes(bytes(frame)) + source_conn.send(self.connections[source_conn].bytes_to_send()) - cont = self._handle_frame(frame, source_conn, other_conn, is_server) - if not cont: - if close_received: - return - else: - close_received = True + for event in self.connections[source_conn].events(): + print('is_server:', is_server, 'event:', event) + if not self._handle_event(event, source_conn, other_conn, is_server): + if close_received: + break + else: + close_received = True except (socket.error, exceptions.TcpException, SSL.Error) as e: s = 'server' if is_server else 'client' self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e))) diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 460d85f84..14dd74056 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -164,19 +164,19 @@ class TestSimple(_WebSocketTest): frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'self.client-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'\xde\xad\xbe\xef' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() assert len(self.master.state.flows) == 2 @@ -213,13 +213,13 @@ class TestSimpleTLS(_WebSocketTest): frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'self.client-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() @@ -234,7 +234,7 @@ class TestPing(_WebSocketTest): assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' - wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done'))) wfile.flush() def test_ping(self): @@ -244,12 +244,12 @@ class TestPing(_WebSocketTest): assert frame.header.opcode == websockets.OPCODE.PING assert frame.payload == b'foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == websockets.OPCODE.TEXT - assert frame.payload == b'pong-received' + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'done' class TestPong(_WebSocketTest): @@ -266,7 +266,7 @@ class TestPong(_WebSocketTest): def test_pong(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) @@ -289,7 +289,7 @@ class TestClose(_WebSocketTest): def test_close(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -299,7 +299,7 @@ class TestClose(_WebSocketTest): def test_close_payload_1(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -309,7 +309,7 @@ class TestClose(_WebSocketTest): def test_close_payload_2(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile)