From 130021b76d781f0ebb43928aa8083c8b8d560882 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 12 Aug 2017 14:06:10 +0200 Subject: [PATCH 1/5] prepare WebSocket stack to move to wsproto --- mitmproxy/proxy/protocol/websocket.py | 229 ++++++++++-------- .../proxy/protocol/test_websocket.py | 26 +- 2 files changed, 146 insertions(+), 109 deletions(-) diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 19546eb2e..d1abd1346 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -3,9 +3,14 @@ import socket import struct from OpenSSL import SSL +from wsproto import events +from wsproto.connection import ConnectionType, WSConnection +from wsproto.extensions import PerMessageDeflate + from mitmproxy import exceptions from mitmproxy import flow from mitmproxy.proxy.protocol import base +from mitmproxy.net import http from mitmproxy.net import tcp from mitmproxy.net import websockets from mitmproxy.websocket import WebSocketFlow, WebSocketMessage @@ -44,108 +49,139 @@ class WebSocketLayer(base.Layer): self.client_frame_buffer = [] self.server_frame_buffer = [] - def _handle_frame(self, frame, source_conn, other_conn, is_server): - if frame.header.opcode & 0x8 == 0: - return self._handle_data_frame(frame, source_conn, other_conn, is_server) - elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG): - return self._handle_ping_pong(frame, source_conn, other_conn, is_server) - elif frame.header.opcode == websockets.OPCODE.CLOSE: - return self._handle_close(frame, source_conn, other_conn, is_server) - else: - return self._handle_unknown_frame(frame, source_conn, other_conn, is_server) + self.connections = {} # type: Dict[object, WSConnection] - def _handle_data_frame(self, frame, source_conn, other_conn, is_server): + extensions = [] + if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers: + if PerMessageDeflate.name in 
handshake_flow.response.headers['Sec-WebSocket-Extensions']: + extensions = [PerMessageDeflate.name] - fb = self.server_frame_buffer if is_server else self.client_frame_buffer - fb.append(frame) + self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER, + extensions=extensions) + self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT, + host=handshake_flow.request.host, + resource=handshake_flow.request.path, + extensions=extensions) - if frame.header.fin: - payload = b''.join(f.payload for f in fb) - original_chunk_sizes = [len(f.payload) for f in fb] - message_type = fb[0].header.opcode - compressed_message = fb[0].header.rsv1 - fb.clear() + data = self.connections[self.server_conn].bytes_to_send() + self.connections[self.client_conn].receive_bytes(data) - websocket_message = WebSocketMessage(message_type, not is_server, payload) - length = len(websocket_message.content) - self.flow.messages.append(websocket_message) - self.channel.ask("websocket_message", self.flow) + event = next(self.connections[self.client_conn].events()) + assert isinstance(event, events.ConnectionRequested) - if not self.flow.stream: - def get_chunk(payload): - if len(payload) == length: - # message has the same length, we can reuse the same sizes - pos = 0 - for s in original_chunk_sizes: - yield payload[pos:pos + s] - pos += s - else: - # just re-chunk everything into 4kB frames - # header len = 4 bytes without masking key and 8 bytes with masking key - chunk_size = 4092 if is_server else 4088 - chunks = range(0, len(payload), chunk_size) - for i in chunks: - yield payload[i:i + chunk_size] + self.connections[self.client_conn].accept(event) + self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send()) + assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished) - frms = [ - websockets.Frame( - payload=chunk, - opcode=frame.header.opcode, - mask=(False if is_server else 1), - 
masking_key=(b'' if is_server else os.urandom(4))) - for chunk in get_chunk(websocket_message.content) - ] - - if len(frms) > 0: - frms[-1].header.fin = True - else: - frms.append(websockets.Frame( - fin=True, - opcode=websockets.OPCODE.CONTINUE, - mask=(False if is_server else 1), - masking_key=(b'' if is_server else os.urandom(4)))) - - frms[0].header.opcode = message_type - frms[0].header.rsv1 = compressed_message - - for frm in frms: - other_conn.send(bytes(frm)) - - else: - other_conn.send(bytes(frame)) - - elif self.flow.stream: - other_conn.send(bytes(frame)) + def _handle_event(self, event, source_conn, other_conn, is_server): + if isinstance(event, events.DataReceived): + return self._handle_data_received(event, source_conn, other_conn, is_server) + elif isinstance(event, events.PingReceived): + return self._handle_ping_received(event, source_conn, other_conn, is_server) + elif isinstance(event, events.PongReceived): + return self._handle_pong_received(event, source_conn, other_conn, is_server) + elif isinstance(event, events.ConnectionFailed): + return self._handle_connection_closed(event, source_conn, other_conn, is_server) + elif isinstance(event, events.ConnectionFailed): + return self._handle_connection_failed(event) + # fail-safe for unhandled events return True - def _handle_ping_pong(self, frame, source_conn, other_conn, is_server): - # just forward the ping/pong to the other side - other_conn.send(bytes(frame)) + def _handle_data_received(self, event, source_conn, other_conn, is_server): return True - def _handle_close(self, frame, source_conn, other_conn, is_server): + def _handle_ping_received(self, event, source_conn, other_conn, is_server): + # PING is automatically answered with a PONG by wsproto + # TODO: log this PING and its payload + self.connections[other_conn].ping(event.payload) + other_conn.send(self.connections[other_conn].bytes_to_send()) + return True + + def _handle_pong_received(self, event, source_conn, other_conn, is_server): + 
# TODO: log this PONG and its payload + self.connections[other_conn].pong(event.payload) + other_conn.send(self.connections[other_conn].bytes_to_send()) + return True + + def _handle_connection_closed(self, event, source_conn, other_conn, is_server): self.flow.close_sender = "server" if is_server else "client" - if len(frame.payload) >= 2: - code, = struct.unpack('!H', frame.payload[:2]) - self.flow.close_code = code - self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code') - if len(frame.payload) > 2: - self.flow.close_reason = frame.payload[2:] + self.flow.close_code = event.code + self.flow.close_reason = event.reason - other_conn.send(bytes(frame)) + print(self.connections[other_conn]) + self.connections[other_conn].close(event.code, event.reason) # initiate close handshake return False - def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server): - # unknown frame - just forward it - other_conn.send(bytes(frame)) + def _handle_connection_failed(self, event): + raise exceptions.TcpException(repr(event)) - sender = "server" if is_server else "client" - self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)]) - - return True + # def _handle_data_frame(self, frame, source_conn, other_conn, is_server): + # + # fb = self.server_frame_buffer if is_server else self.client_frame_buffer + # fb.append(frame) + # + # if frame.header.fin: + # payload = b''.join(f.payload for f in fb) + # original_chunk_sizes = [len(f.payload) for f in fb] + # message_type = fb[0].header.opcode + # compressed_message = fb[0].header.rsv1 + # fb.clear() + # + # websocket_message = WebSocketMessage(message_type, not is_server, payload) + # length = len(websocket_message.content) + # self.flow.messages.append(websocket_message) + # self.channel.ask("websocket_message", self.flow) + # + # if not self.flow.stream: + # def get_chunk(payload): + # if len(payload) == length: + # # message has the same length, 
we can reuse the same sizes + # pos = 0 + # for s in original_chunk_sizes: + # yield payload[pos:pos + s] + # pos += s + # else: + # # just re-chunk everything into 4kB frames + # # header len = 4 bytes without masking key and 8 bytes with masking key + # chunk_size = 4092 if is_server else 4088 + # chunks = range(0, len(payload), chunk_size) + # for i in chunks: + # yield payload[i:i + chunk_size] + # + # frms = [ + # websockets.Frame( + # payload=chunk, + # opcode=frame.header.opcode, + # mask=(False if is_server else 1), + # masking_key=(b'' if is_server else os.urandom(4))) + # for chunk in get_chunk(websocket_message.content) + # ] + # + # if len(frms) > 0: + # frms[-1].header.fin = True + # else: + # frms.append(websockets.Frame( + # fin=True, + # opcode=websockets.OPCODE.CONTINUE, + # mask=(False if is_server else 1), + # masking_key=(b'' if is_server else os.urandom(4)))) + # + # frms[0].header.opcode = message_type + # frms[0].header.rsv1 = compressed_message + # + # for frm in frms: + # other_conn.send(bytes(frm)) + # + # else: + # other_conn.send(bytes(frame)) + # + # elif self.flow.stream: + # other_conn.send(bytes(frame)) + # + # return True def __call__(self): self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self) @@ -153,27 +189,28 @@ class WebSocketLayer(base.Layer): self.handshake_flow.metadata['websocket_flow'] = self.flow.id self.channel.ask("websocket_start", self.flow) - client = self.client_conn.connection - server = self.server_conn.connection - conns = [client, server] + conns = [c.connection for c in self.connections.keys()] close_received = False try: while not self.channel.should_exit.is_set(): r = tcp.ssl_read_select(conns, 0.1) for conn in r: - source_conn = self.client_conn if conn == client else self.server_conn - other_conn = self.server_conn if conn == client else self.client_conn - is_server = (conn == self.server_conn.connection) + source_conn = self.client_conn if conn == 
self.client_conn.connection else self.server_conn + other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn + is_server = (source_conn == self.server_conn) frame = websockets.Frame.from_file(source_conn.rfile) + self.connections[source_conn].receive_bytes(bytes(frame)) + source_conn.send(self.connections[source_conn].bytes_to_send()) - cont = self._handle_frame(frame, source_conn, other_conn, is_server) - if not cont: - if close_received: - return - else: - close_received = True + for event in self.connections[source_conn].events(): + print('is_server:', is_server, 'event:', event) + if not self._handle_event(event, source_conn, other_conn, is_server): + if close_received: + break + else: + close_received = True except (socket.error, exceptions.TcpException, SSL.Error) as e: s = 'server' if is_server else 'client' self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e))) diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 460d85f84..14dd74056 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -164,19 +164,19 @@ class TestSimple(_WebSocketTest): frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'self.client-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) 
self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'\xde\xad\xbe\xef' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() assert len(self.master.state.flows) == 2 @@ -213,13 +213,13 @@ class TestSimpleTLS(_WebSocketTest): frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'self.client-foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() @@ -234,7 +234,7 @@ class TestPing(_WebSocketTest): assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' - wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done'))) wfile.flush() def test_ping(self): @@ -244,12 +244,12 @@ class TestPing(_WebSocketTest): assert frame.header.opcode == websockets.OPCODE.PING assert frame.payload == b'foobar' - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == websockets.OPCODE.TEXT 
- assert frame.payload == b'pong-received' + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'done' class TestPong(_WebSocketTest): @@ -266,7 +266,7 @@ class TestPong(_WebSocketTest): def test_pong(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) @@ -289,7 +289,7 @@ class TestClose(_WebSocketTest): def test_close(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -299,7 +299,7 @@ class TestClose(_WebSocketTest): def test_close_payload_1(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) @@ -309,7 +309,7 @@ class TestClose(_WebSocketTest): def test_close_payload_2(self): self.setup_connection() - self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) self.client.wfile.flush() websockets.Frame.from_file(self.client.rfile) From 5214f544e7b690dea2a45cb4cda44bbffec9a77e Mon Sep 17 00:00:00 2001 From: Ujjwal Verma Date: Thu, 17 Aug 2017 21:12:07 +0530 Subject: [PATCH 2/5] Use wsproto for websockets --- mitmproxy/addons/dumper.py | 2 + 
mitmproxy/proxy/protocol/websocket.py | 159 ++++++++---------- mitmproxy/tools/console/consoleaddons.py | 2 +- setup.cfg | 9 +- .../proxy/protocol/test_websocket.py | 122 ++++++++++++-- 5 files changed, 185 insertions(+), 109 deletions(-) diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py index 54526d5b2..48bc81187 100644 --- a/mitmproxy/addons/dumper.py +++ b/mitmproxy/addons/dumper.py @@ -234,6 +234,8 @@ class Dumper: message = f.messages[-1] self.echo(f.message_info(message)) if ctx.options.flow_detail >= 3: + message = message.from_state(message.get_state()) + message.content = message.content.encode() if isinstance(message.content, str) else message.content self._echo_message(message) def websocket_end(self, f): diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index d1abd1346..54d8120de 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -1,19 +1,18 @@ -import os import socket -import struct from OpenSSL import SSL from wsproto import events from wsproto.connection import ConnectionType, WSConnection from wsproto.extensions import PerMessageDeflate +from wsproto.frame_protocol import Opcode from mitmproxy import exceptions from mitmproxy import flow from mitmproxy.proxy.protocol import base -from mitmproxy.net import http from mitmproxy.net import tcp from mitmproxy.net import websockets from mitmproxy.websocket import WebSocketFlow, WebSocketMessage +from mitmproxy.utils import strutils class WebSocketLayer(base.Layer): @@ -54,14 +53,16 @@ class WebSocketLayer(base.Layer): extensions = [] if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers: if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']: - extensions = [PerMessageDeflate.name] - + extensions = [PerMessageDeflate()] self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER, extensions=extensions) self.connections[self.server_conn] = 
WSConnection(ConnectionType.CLIENT, host=handshake_flow.request.host, resource=handshake_flow.request.path, extensions=extensions) + if extensions: + for conn in self.connections.values(): + conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions']) data = self.connections[self.server_conn].bytes_to_send() self.connections[self.client_conn].receive_bytes(data) @@ -80,28 +81,78 @@ class WebSocketLayer(base.Layer): return self._handle_ping_received(event, source_conn, other_conn, is_server) elif isinstance(event, events.PongReceived): return self._handle_pong_received(event, source_conn, other_conn, is_server) - elif isinstance(event, events.ConnectionFailed): + elif isinstance(event, events.ConnectionClosed): return self._handle_connection_closed(event, source_conn, other_conn, is_server) - elif isinstance(event, events.ConnectionFailed): - return self._handle_connection_failed(event) # fail-safe for unhandled events - return True + return True # pragma: no cover def _handle_data_received(self, event, source_conn, other_conn, is_server): + fb = self.server_frame_buffer if is_server else self.client_frame_buffer + fb.append(event.data) + + if event.message_finished: + original_chunk_sizes = [len(f) for f in fb] + message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY + if message_type == Opcode.TEXT: + payload = ''.join(fb) + else: + payload = b''.join(fb) + fb.clear() + + websocket_message = WebSocketMessage(message_type, not is_server, payload) + length = len(websocket_message.content) + self.flow.messages.append(websocket_message) + self.channel.ask("websocket_message", self.flow) + + if not self.flow.stream: + def get_chunk(payload): + if len(payload) == length: + # message has the same length, we can reuse the same sizes + pos = 0 + for s in original_chunk_sizes: + yield (payload[pos:pos + s], True if pos + s == length else False) + pos += s + else: + # just re-chunk everything into 4kB frames 
+ # header len = 4 bytes without masking key and 8 bytes with masking key + chunk_size = 4092 if is_server else 4088 + chunks = range(0, len(payload), chunk_size) + for i in chunks: + yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False) + + for chunk, final in get_chunk(websocket_message.content): + self.connections[other_conn].send_data(chunk, final) + other_conn.send(self.connections[other_conn].bytes_to_send()) + + else: + self.connections[other_conn].send_data(event.data, event.message_finished) + other_conn.send(self.connections[other_conn].bytes_to_send()) + + elif self.flow.stream: + self.connections[other_conn].send_data(event.data, event.message_finished) + other_conn.send(self.connections[other_conn].bytes_to_send()) + return True def _handle_ping_received(self, event, source_conn, other_conn, is_server): # PING is automatically answered with a PONG by wsproto - # TODO: log this PING and its payload - self.connections[other_conn].ping(event.payload) + self.connections[other_conn].ping() other_conn.send(self.connections[other_conn].bytes_to_send()) + source_conn.send(self.connections[source_conn].bytes_to_send()) + self.log( + "Ping Received from {}".format("server" if is_server else "client"), + "info", + [strutils.bytes_to_escaped_str(bytes(event.payload))] + ) return True def _handle_pong_received(self, event, source_conn, other_conn, is_server): - # TODO: log this PONG and its payload - self.connections[other_conn].pong(event.payload) - other_conn.send(self.connections[other_conn].bytes_to_send()) + self.log( + "Pong Received from {}".format("server" if is_server else "client"), + "info", + [strutils.bytes_to_escaped_str(bytes(event.payload))] + ) return True def _handle_connection_closed(self, event, source_conn, other_conn, is_server): @@ -109,80 +160,12 @@ class WebSocketLayer(base.Layer): self.flow.close_code = event.code self.flow.close_reason = event.reason - print(self.connections[other_conn]) 
self.connections[other_conn].close(event.code, event.reason) + other_conn.send(self.connections[other_conn].bytes_to_send()) + source_conn.send(self.connections[source_conn].bytes_to_send()) - # initiate close handshake return False - def _handle_connection_failed(self, event): - raise exceptions.TcpException(repr(event)) - - # def _handle_data_frame(self, frame, source_conn, other_conn, is_server): - # - # fb = self.server_frame_buffer if is_server else self.client_frame_buffer - # fb.append(frame) - # - # if frame.header.fin: - # payload = b''.join(f.payload for f in fb) - # original_chunk_sizes = [len(f.payload) for f in fb] - # message_type = fb[0].header.opcode - # compressed_message = fb[0].header.rsv1 - # fb.clear() - # - # websocket_message = WebSocketMessage(message_type, not is_server, payload) - # length = len(websocket_message.content) - # self.flow.messages.append(websocket_message) - # self.channel.ask("websocket_message", self.flow) - # - # if not self.flow.stream: - # def get_chunk(payload): - # if len(payload) == length: - # # message has the same length, we can reuse the same sizes - # pos = 0 - # for s in original_chunk_sizes: - # yield payload[pos:pos + s] - # pos += s - # else: - # # just re-chunk everything into 4kB frames - # # header len = 4 bytes without masking key and 8 bytes with masking key - # chunk_size = 4092 if is_server else 4088 - # chunks = range(0, len(payload), chunk_size) - # for i in chunks: - # yield payload[i:i + chunk_size] - # - # frms = [ - # websockets.Frame( - # payload=chunk, - # opcode=frame.header.opcode, - # mask=(False if is_server else 1), - # masking_key=(b'' if is_server else os.urandom(4))) - # for chunk in get_chunk(websocket_message.content) - # ] - # - # if len(frms) > 0: - # frms[-1].header.fin = True - # else: - # frms.append(websockets.Frame( - # fin=True, - # opcode=websockets.OPCODE.CONTINUE, - # mask=(False if is_server else 1), - # masking_key=(b'' if is_server else os.urandom(4)))) - # - # 
frms[0].header.opcode = message_type - # frms[0].header.rsv1 = compressed_message - # - # for frm in frms: - # other_conn.send(bytes(frm)) - # - # else: - # other_conn.send(bytes(frame)) - # - # elif self.flow.stream: - # other_conn.send(bytes(frame)) - # - # return True - def __call__(self): self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self) self.flow.metadata['websocket_handshake'] = self.handshake_flow.id @@ -204,12 +187,12 @@ class WebSocketLayer(base.Layer): self.connections[source_conn].receive_bytes(bytes(frame)) source_conn.send(self.connections[source_conn].bytes_to_send()) + if close_received: + return + for event in self.connections[source_conn].events(): - print('is_server:', is_server, 'event:', event) if not self._handle_event(event, source_conn, other_conn, is_server): - if close_received: - break - else: + if not close_received: close_received = True except (socket.error, exceptions.TcpException, SSL.Error) as e: s = 'server' if is_server else 'client' diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 1bda219f3..8233d45e4 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -49,7 +49,7 @@ class UnsupportedLog: def websocket_message(self, f): message = f.messages[-1] signals.add_log(f.message_info(message), "info") - signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug") + signals.add_log(message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content), "debug") def websocket_end(self, f): signals.add_log("WebSocket connection closed by {}: {} {}, {}".format( diff --git a/setup.cfg b/setup.cfg index eaabfa12c..fd31d15b5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,7 +21,13 @@ exclude_lines = [tool:full_coverage] exclude = - mitmproxy/proxy/protocol/ + mitmproxy/proxy/protocol/base.py + mitmproxy/proxy/protocol/http.py + mitmproxy/proxy/protocol/http1.py 
+ mitmproxy/proxy/protocol/http2.py + mitmproxy/proxy/protocol/http_replay.py + mitmproxy/proxy/protocol/rawtcp.py + mitmproxy/proxy/protocol/tls.py mitmproxy/proxy/root_context.py mitmproxy/proxy/server.py mitmproxy/tools/ @@ -64,7 +70,6 @@ exclude = mitmproxy/proxy/protocol/http_replay.py mitmproxy/proxy/protocol/rawtcp.py mitmproxy/proxy/protocol/tls.py - mitmproxy/proxy/protocol/websocket.py mitmproxy/proxy/root_context.py mitmproxy/proxy/server.py mitmproxy/stateobject.py diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 14dd74056..a7acdc4db 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -1,5 +1,6 @@ import pytest import os +import struct import tempfile import traceback @@ -33,6 +34,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase): connection='upgrade', upgrade='websocket', sec_websocket_accept=b'', + sec_websocket_extensions='permessage-deflate' if "permessage-deflate" in request.headers.values() else '' ), content=b'', ) @@ -80,7 +82,7 @@ class _WebSocketTestBase: if self.client: self.client.close() - def setup_connection(self): + def setup_connection(self, extension=False): self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) self.client.connect() @@ -115,6 +117,7 @@ class _WebSocketTestBase: upgrade="websocket", sec_websocket_version="13", sec_websocket_key="1234", + sec_websocket_extensions="permessage-deflate" if extension else "" ), content=b'') self.client.wfile.write(http.http1.assemble_request(request)) @@ -145,11 +148,11 @@ class TestSimple(_WebSocketTest): wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, 
payload=frame.payload))) wfile.flush() @pytest.mark.parametrize('streaming', [True, False]) @@ -183,17 +186,40 @@ class TestSimple(_WebSocketTest): assert isinstance(self.master.state.flows[0], HTTPFlow) assert isinstance(self.master.state.flows[1], WebSocketFlow) assert len(self.master.state.flows[1].messages) == 5 - assert self.master.state.flows[1].messages[0].content == b'server-foobar' + assert self.master.state.flows[1].messages[0].content == 'server-foobar' assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[1].content == b'self.client-foobar' + assert self.master.state.flows[1].messages[1].content == 'self.client-foobar' assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[2].content == b'self.client-foobar' + assert self.master.state.flows[1].messages[2].content == 'self.client-foobar' assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY + def test_change_payload(self): + class Addon: + def websocket_message(self, f): + f.messages[-1].content = "foo" + + self.master.addons.add(Addon()) + self.setup_connection() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + 
self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + class TestSimpleTLS(_WebSocketTest): ssl = True @@ -204,7 +230,7 @@ class TestSimpleTLS(_WebSocketTest): wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() def test_simple_tls(self): @@ -237,19 +263,21 @@ class TestPing(_WebSocketTest): wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done'))) wfile.flush() + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + wfile.flush() + websockets.Frame.from_file(rfile) + def test_ping(self): self.setup_connection() frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' - - self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + websockets.Frame.from_file(self.client.rfile) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'' # We don't send payload to other end - frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == websockets.OPCODE.PONG - assert frame.payload == b'done' + assert self.master.has_log("Pong Received from server", "info") class TestPong(_WebSocketTest): @@ -258,11 +286,15 @@ class TestPong(_WebSocketTest): def handle_websockets(cls, rfile, wfile): frame = websockets.Frame.from_file(rfile) assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' + assert frame.payload == b'' wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) wfile.flush() + wfile.write(bytes(websockets.Frame(fin=1, 
opcode=websockets.OPCODE.CLOSE))) + wfile.flush() + websockets.Frame.from_file(rfile) + def test_pong(self): self.setup_connection() @@ -270,8 +302,13 @@ class TestPong(_WebSocketTest): self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) + websockets.Frame.from_file(self.client.rfile) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() + assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' + assert self.master.has_log("Pong Received from server", "info") class TestClose(_WebSocketTest): @@ -279,7 +316,7 @@ class TestClose(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) wfile.flush() @@ -329,8 +366,9 @@ class TestInvalidFrame(_WebSocketTest): # with pytest.raises(exceptions.TcpDisconnect): frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == 15 - assert frame.payload == b'foobar' + code, = struct.unpack('!H', frame.payload[:2]) + assert code == 1002 + assert frame.payload[2:].startswith(b'Invalid opcode') class TestStreaming(_WebSocketTest): @@ -360,3 +398,51 @@ class TestStreaming(_WebSocketTest): assert frame assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received + + +class TestExtension(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00') + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.rsv1 + wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00') + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.rsv1 + 
wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00') + wfile.flush() + + def test_extension(self): + self.setup_connection(True) + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v') + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c') + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + assert len(self.master.state.flows[1].messages) == 5 + assert self.master.state.flows[1].messages[0].content == 'server-foobar' + assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[1].content == 'client-foobar' + assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[2].content == 'client-foobar' + assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY + assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY From 3cb459d56daeae8fd2b923c27f39ca5595a50e7b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 17 Aug 2017 10:18:05 +0200 Subject: [PATCH 3/5] docs++: add individual protocol pages --- docs/features/passthrough.rst | 4 ++-- docs/index.rst | 10 +++++++++- docs/protocols/http1.rst | 15 +++++++++++++++ docs/protocols/http2.rst | 16 ++++++++++++++++ docs/{features => protocols}/tcpproxy.rst | 6 +++--- docs/protocols/websocket.rst | 17 +++++++++++++++++ docs/scripting/events.rst | 2 +- 7 files changed, 63 insertions(+), 7 deletions(-) create 
mode 100644 docs/protocols/http1.rst create mode 100644 docs/protocols/http2.rst rename docs/{features => protocols}/tcpproxy.rst (93%) create mode 100644 docs/protocols/websocket.rst diff --git a/docs/features/passthrough.rst b/docs/features/passthrough.rst index 00462e9d9..dbaf35061 100644 --- a/docs/features/passthrough.rst +++ b/docs/features/passthrough.rst @@ -13,7 +13,7 @@ mechanism: away. Note that mitmproxy's "Limit" option is often the better alternative here, as it is not affected by the limitations listed below. -If you want to peek into (SSL-protected) non-HTTP connections, check out the :ref:`tcpproxy` +If you want to peek into (SSL-protected) non-HTTP connections, check out the :ref:`tcp_proxy` feature. If you want to ignore traffic from mitmproxy's processing because of large response bodies, take a look at the :ref:`streaming` feature. @@ -88,7 +88,7 @@ Here are some other examples for ignore patterns: .. seealso:: - - :ref:`tcpproxy` + - :ref:`tcp_proxy` - :ref:`streaming` - mitmproxy's "Limit" feature diff --git a/docs/index.rst b/docs/index.rst index 7cf593ff2..8dba4d04d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,6 +20,15 @@ mitmweb config +.. toctree:: + :hidden: + :caption: Protocols + + protocols/http1 + protocols/http2 + protocols/websocket + protocols/tcpproxy + .. toctree:: :hidden: :caption: Features @@ -36,7 +45,6 @@ features/streaming features/socksproxy features/sticky - features/tcpproxy features/upstreamproxy features/upstreamcerts diff --git a/docs/protocols/http1.rst b/docs/protocols/http1.rst new file mode 100644 index 000000000..21e68785e --- /dev/null +++ b/docs/protocols/http1.rst @@ -0,0 +1,15 @@ +.. _http1_protocol: + +HTTP/1.0 and HTTP/1.1 +=========================== + +.. 
seealso:: + + - `RFC7230: HTTP/1.1: Message Syntax and Routing <https://tools.ietf.org/html/rfc7230>`_ + - `RFC7231: HTTP/1.1: Semantics and Content <https://tools.ietf.org/html/rfc7231>`_ + +HTTP/1.0 and HTTP/1.1 support in mitmproxy is based on our custom HTTP stack, +which takes care of all semantics and on-the-wire parsing/serialization tasks. + +mitmproxy currently does not support HTTP trailers - but if you want to send +us a PR, we promise to take a look! diff --git a/docs/protocols/http2.rst b/docs/protocols/http2.rst new file mode 100644 index 000000000..b3268ae5b --- /dev/null +++ b/docs/protocols/http2.rst @@ -0,0 +1,16 @@ +.. _http2_protocol: + +HTTP/2 +====== + +.. seealso:: + + - `RFC7540: Hypertext Transfer Protocol Version 2 (HTTP/2) <https://tools.ietf.org/html/rfc7540>`_ + +HTTP/2 support in mitmproxy is based on the amazing work by the python-hyper +community with the `hyper-h2 <https://github.com/python-hyper/hyper-h2>`_ +project. It fully encapsulates the internal state of HTTP/2 connections and +provides an easy-to-use event-based API. + +mitmproxy currently does not support HTTP/2 trailers - but if you want to send +us a PR, we promise to take a look! diff --git a/docs/features/tcpproxy.rst b/docs/protocols/tcpproxy.rst similarity index 93% rename from docs/features/tcpproxy.rst rename to docs/protocols/tcpproxy.rst index cba374e3d..772485732 100644 --- a/docs/features/tcpproxy.rst +++ b/docs/protocols/tcpproxy.rst @@ -1,7 +1,7 @@ -.. _tcpproxy: +.. _tcp_proxy: -TCP Proxy -========= +TCP Proxy / Fallback +==================== In case mitmproxy does not handle a specific protocol, you can exempt hostnames from processing, so that mitmproxy acts as a generic TCP forwarder. diff --git a/docs/protocols/websocket.rst b/docs/protocols/websocket.rst new file mode 100644 index 000000000..85cff3aca --- /dev/null +++ b/docs/protocols/websocket.rst @@ -0,0 +1,17 @@ +.. _websocket_protocol: + +WebSocket +========= + +..
seealso:: + + - `RFC6455: The WebSocket Protocol <https://tools.ietf.org/html/rfc6455>`_ + - `RFC7692: Compression Extensions for WebSocket <https://tools.ietf.org/html/rfc7692>`_ + +WebSocket support in mitmproxy is based on the amazing work by the python-hyper +community with the `wsproto <https://github.com/python-hyper/wsproto>`_ +project. It fully encapsulates WebSocket frames/messages/connections and +provides an easy-to-use event-based API. + +mitmproxy fully supports the compression extension for WebSocket messages, +provided by wsproto. diff --git a/docs/scripting/events.rst b/docs/scripting/events.rst index 8f9463ffe..9e84dacfe 100644 --- a/docs/scripting/events.rst +++ b/docs/scripting/events.rst @@ -211,7 +211,7 @@ TCP Events ---------- These events are called only if the connection is in :ref:`TCP mode -<tcpproxy>`. So, for instance, TCP events are not called for ordinary HTTP/S +<tcp_proxy>`. So, for instance, TCP events are not called for ordinary HTTP/S connections. .. list-table:: From 70e1409261adfd165b8473f1d21aa760023795d7 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 17 Aug 2017 10:38:02 +0200 Subject: [PATCH 4/5] docs++: add websocket PING/PONG --- docs/protocols/websocket.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/protocols/websocket.rst b/docs/protocols/websocket.rst index 85cff3aca..8a7e807f7 100644 --- a/docs/protocols/websocket.rst +++ b/docs/protocols/websocket.rst @@ -15,3 +15,8 @@ provides an easy-to-use event-based API. mitmproxy fully supports the compression extension for WebSocket messages, provided by wsproto. + +If an endpoint sends a PING to mitmproxy, a PONG will be sent back immediately +(with the same payload if present). To keep the other connection alive, a new +PING (without a payload) is sent to the other endpoint. Unsolicited PONGs are +not forwarded. All PINGs and PONGs are logged (with payload if present).
From f5fafbfcb56bbc3fb7cca7ed32dd7b3b41c39e83 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 12 Dec 2017 21:47:24 +0100 Subject: [PATCH 5/5] vendoring of wsproto https://github.com/python-hyper/wsproto.git commit 5ea2da61266796666f5de6461aaae22e6b00deba --- mitmproxy/contrib/wsproto/compat.py | 20 + mitmproxy/contrib/wsproto/connection.py | 477 ++++++++++++++++ mitmproxy/contrib/wsproto/events.py | 81 +++ mitmproxy/contrib/wsproto/extensions.py | 257 +++++++++ mitmproxy/contrib/wsproto/frame_protocol.py | 579 ++++++++++++++++++++ mitmproxy/proxy/protocol/websocket.py | 8 +- setup.py | 1 + 7 files changed, 1419 insertions(+), 4 deletions(-) create mode 100644 mitmproxy/contrib/wsproto/compat.py create mode 100644 mitmproxy/contrib/wsproto/connection.py create mode 100644 mitmproxy/contrib/wsproto/events.py create mode 100644 mitmproxy/contrib/wsproto/extensions.py create mode 100644 mitmproxy/contrib/wsproto/frame_protocol.py diff --git a/mitmproxy/contrib/wsproto/compat.py b/mitmproxy/contrib/wsproto/compat.py new file mode 100644 index 000000000..1911f83cf --- /dev/null +++ b/mitmproxy/contrib/wsproto/compat.py @@ -0,0 +1,20 @@ +# flake8: noqa + +import sys + + +PY2 = sys.version_info.major == 2 +PY3 = sys.version_info.major == 3 + + +if PY3: + unicode = str + + def Utf8Validator(): + return None +else: + unicode = unicode + try: + from wsaccel.utf8validator import Utf8Validator + except ImportError: + from .utf8validator import Utf8Validator diff --git a/mitmproxy/contrib/wsproto/connection.py b/mitmproxy/contrib/wsproto/connection.py new file mode 100644 index 000000000..f994cd3ab --- /dev/null +++ b/mitmproxy/contrib/wsproto/connection.py @@ -0,0 +1,477 @@ +# -*- coding: utf-8 -*- +""" +wsproto/connection +~~~~~~~~~~~~~~ + +An implementation of a WebSocket connection. 
+""" + +import os +import base64 +import hashlib +from collections import deque + +from enum import Enum + +import h11 + +from .events import ( + ConnectionRequested, ConnectionEstablished, ConnectionClosed, + ConnectionFailed, TextReceived, BytesReceived, PingReceived, PongReceived +) +from .frame_protocol import FrameProtocol, ParseFailed, CloseReason, Opcode + + +# RFC6455, Section 1.3 - Opening Handshake +ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +class ConnectionState(Enum): + """ + RFC 6455, Section 4 - Opening Handshake + """ + CONNECTING = 0 + OPEN = 1 + CLOSING = 2 + CLOSED = 3 + + +class ConnectionType(Enum): + CLIENT = 1 + SERVER = 2 + + +CLIENT = ConnectionType.CLIENT +SERVER = ConnectionType.SERVER + + +# Some convenience utilities for working with HTTP headers +def _normed_header_dict(h11_headers): + # This mangles Set-Cookie headers. But it happens that we don't care about + # any of those, so it's OK. For every other HTTP header, if there are + # multiple instances then you're allowed to join them together with + # commas. + name_to_values = {} + for name, value in h11_headers: + name_to_values.setdefault(name, []).append(value) + name_to_normed_value = {} + for name, values in name_to_values.items(): + name_to_normed_value[name] = b", ".join(values) + return name_to_normed_value + + +# We use this for parsing the proposed protocol list, and for parsing the +# proposed and accepted extension lists. For the proposed protocol list it's +# fine, because the ABNF is just 1#token. But for the extension lists, it's +# wrong, because those can contain quoted strings, which can in turn contain +# commas. XX FIXME +def _split_comma_header(value): + return [piece.decode('ascii').strip() for piece in value.split(b',')] + + +class WSConnection(object): + """ + A low-level WebSocket connection object. 
+ + This wraps two other protocol objects, an HTTP/1.1 protocol object used + to do the initial HTTP upgrade handshake and a WebSocket frame protocol + object used to exchange messages and other control frames. + + :param conn_type: Whether this object is on the client- or server-side of + a connection. To initialise as a client pass ``CLIENT`` otherwise + pass ``SERVER``. + :type conn_type: ``ConnectionType`` + + :param host: The hostname to pass to the server when acting as a client. + :type host: ``str`` + + :param resource: The resource (aka path) to pass to the server when acting + as a client. + :type resource: ``str`` + + :param extensions: A list of extensions to use on this connection. + Extensions should be instances of a subclass of + :class:`Extension `. + + :param subprotocols: A list of subprotocols to request when acting as a + client, ordered by preference. This has no impact on the connection + itself. + :type subprotocol: ``list`` of ``str`` + """ + + def __init__(self, conn_type, host=None, resource=None, extensions=None, + subprotocols=None): + self.client = conn_type is ConnectionType.CLIENT + + self.host = host + self.resource = resource + + self.subprotocols = subprotocols or [] + self.extensions = extensions or [] + + self.version = b'13' + + self._state = ConnectionState.CONNECTING + self._close_reason = None + + self._nonce = None + self._outgoing = b'' + self._events = deque() + self._proto = None + + if self.client: + self._upgrade_connection = h11.Connection(h11.CLIENT) + else: + self._upgrade_connection = h11.Connection(h11.SERVER) + + if self.client: + if self.host is None: + raise ValueError( + "Host must not be None for a client-side connection.") + if self.resource is None: + raise ValueError( + "Resource must not be None for a client-side connection.") + self.initiate_connection() + + def initiate_connection(self): + self._generate_nonce() + + headers = { + b"Host": self.host.encode('ascii'), + b"Upgrade": b'WebSocket', + 
b"Connection": b'Upgrade', + b"Sec-WebSocket-Key": self._nonce, + b"Sec-WebSocket-Version": self.version, + } + + if self.subprotocols: + headers[b"Sec-WebSocket-Protocol"] = ", ".join(self.subprotocols) + + if self.extensions: + offers = {e.name: e.offer(self) for e in self.extensions} + extensions = [] + for name, params in offers.items(): + if params is True: + extensions.append(name.encode('ascii')) + elif params: + # py34 annoyance: doesn't support bytestring formatting + extensions.append(('%s; %s' % (name, params)) + .encode("ascii")) + if extensions: + headers[b'Sec-WebSocket-Extensions'] = b', '.join(extensions) + + upgrade = h11.Request(method=b'GET', target=self.resource, + headers=headers.items()) + self._outgoing += self._upgrade_connection.send(upgrade) + + def send_data(self, payload, final=True): + """ + Send a message or part of a message to the remote peer. + + If ``final`` is ``False`` it indicates that this is part of a longer + message. If ``final`` is ``True`` it indicates that this is either a + self-contained message or the last part of a longer message. + + If ``payload`` is of type ``bytes`` then the message is flagged as + being binary If it is of type ``str`` encoded as UTF-8 and sent as + text. + + :param payload: The message body to send. + :type payload: ``bytes`` or ``str`` + + :param final: Whether there are more parts to this message to be sent. + :type final: ``bool`` + """ + + self._outgoing += self._proto.send_data(payload, final) + + def close(self, code=CloseReason.NORMAL_CLOSURE, reason=None): + self._outgoing += self._proto.close(code, reason) + self._state = ConnectionState.CLOSING + + @property + def closed(self): + return self._state is ConnectionState.CLOSED + + def bytes_to_send(self, amount=None): + """ + Return any data that is to be sent to the remote peer. + + :param amount: (optional) The maximum number of bytes to be provided. + If ``None`` or not provided it will return all available bytes. 
+ :type amount: ``int`` + """ + + if amount is None: + data = self._outgoing + self._outgoing = b'' + else: + data = self._outgoing[:amount] + self._outgoing = self._outgoing[amount:] + + return data + + def receive_bytes(self, data): + """ + Pass some received bytes to the connection for processing. + + :param data: The data received from the remote peer. + :type data: ``bytes`` + """ + + if data is None and self._state is ConnectionState.OPEN: + # "If _The WebSocket Connection is Closed_ and no Close control + # frame was received by the endpoint (such as could occur if the + # underlying transport connection is lost), _The WebSocket + # Connection Close Code_ is considered to be 1006." + self._events.append(ConnectionClosed(CloseReason.ABNORMAL_CLOSURE)) + self._state = ConnectionState.CLOSED + return + elif data is None: + self._state = ConnectionState.CLOSED + return + + if self._state is ConnectionState.CONNECTING: + event, data = self._process_upgrade(data) + if event is not None: + self._events.append(event) + + if self._state is ConnectionState.OPEN: + self._proto.receive_bytes(data) + + def _process_upgrade(self, data): + self._upgrade_connection.receive_data(data) + while True: + try: + event = self._upgrade_connection.next_event() + except h11.RemoteProtocolError: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad HTTP message"), b'' + if event is h11.NEED_DATA: + break + elif self.client and isinstance(event, (h11.InformationalResponse, + h11.Response)): + data = self._upgrade_connection.trailing_data[0] + return self._establish_client_connection(event), data + elif not self.client and isinstance(event, h11.Request): + return self._process_connection_request(event), None + else: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad HTTP message"), b'' + + self._incoming = b'' + return None, None + + def events(self): + """ + Return a generator that provides any events that have been generated + by protocol activity. 
+ + :returns: generator + """ + + while self._events: + yield self._events.popleft() + + if self._proto is None: + return + + try: + for frame in self._proto.received_frames(): + if frame.opcode is Opcode.PING: + assert frame.frame_finished and frame.message_finished + self._outgoing += self._proto.pong(frame.payload) + yield PingReceived(frame.payload) + + elif frame.opcode is Opcode.PONG: + assert frame.frame_finished and frame.message_finished + yield PongReceived(frame.payload) + + elif frame.opcode is Opcode.CLOSE: + code, reason = frame.payload + self.close(code, reason) + yield ConnectionClosed(code, reason) + + elif frame.opcode is Opcode.TEXT: + yield TextReceived(frame.payload, + frame.frame_finished, + frame.message_finished) + + elif frame.opcode is Opcode.BINARY: + yield BytesReceived(frame.payload, + frame.frame_finished, + frame.message_finished) + except ParseFailed as exc: + # XX FIXME: apparently autobahn intentionally deviates from the + # spec in that on protocol errors it just closes the connection + # rather than trying to send a CLOSE frame. Investigate whether we + # should do the same. + self.close(code=exc.code, reason=str(exc)) + yield ConnectionClosed(exc.code, reason=str(exc)) + + def _generate_nonce(self): + # os.urandom may be overkill for this use case, but I don't think this + # is a bottleneck, and better safe than sorry... 
+ self._nonce = base64.b64encode(os.urandom(16)) + + def _generate_accept_token(self, token): + accept_token = token + ACCEPT_GUID + accept_token = hashlib.sha1(accept_token).digest() + return base64.b64encode(accept_token) + + def _establish_client_connection(self, event): + if event.status_code != 101: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad status code from server") + headers = _normed_header_dict(event.headers) + if headers[b'connection'].lower() != b'upgrade': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Connection: Upgrade header") + if headers[b'upgrade'].lower() != b'websocket': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Upgrade: WebSocket header") + + accept_token = self._generate_accept_token(self._nonce) + if headers[b'sec-websocket-accept'] != accept_token: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad accept token") + + subprotocol = headers.get(b'sec-websocket-protocol', None) + if subprotocol is not None: + subprotocol = subprotocol.decode('ascii') + if subprotocol not in self.subprotocols: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "unrecognized subprotocol {!r}" + .format(subprotocol)) + + extensions = headers.get(b'sec-websocket-extensions', None) + if extensions: + accepts = _split_comma_header(extensions) + + for accept in accepts: + name = accept.split(';', 1)[0].strip() + for extension in self.extensions: + if extension.name == name: + extension.finalize(self, accept) + break + else: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "unrecognized extension {!r}" + .format(name)) + + self._proto = FrameProtocol(self.client, self.extensions) + self._state = ConnectionState.OPEN + return ConnectionEstablished(subprotocol, extensions) + + def _process_connection_request(self, event): + if event.method != b'GET': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Request method must be GET") + headers = _normed_header_dict(event.headers) + if 
headers[b'connection'].lower() != b'upgrade': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Connection: Upgrade header") + if headers[b'upgrade'].lower() != b'websocket': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Upgrade: WebSocket header") + + if b'sec-websocket-version' not in headers: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Sec-WebSocket-Version header") + # XX FIXME: need to check Sec-Websocket-Version, and respond with a + # 400 if it's not what we expect + + if b'sec-websocket-protocol' in headers: + proposed_subprotocols = _split_comma_header( + headers[b'sec-websocket-protocol']) + else: + proposed_subprotocols = [] + + if b'sec-websocket-key' not in headers: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Sec-WebSocket-Key header") + + return ConnectionRequested(proposed_subprotocols, event) + + def _extension_accept(self, extensions_header): + accepts = {} + offers = _split_comma_header(extensions_header) + + for offer in offers: + name = offer.split(';', 1)[0].strip() + for extension in self.extensions: + if extension.name == name: + accept = extension.accept(self, offer) + if accept is True: + accepts[extension.name] = True + elif accept is not False and accept is not None: + accepts[extension.name] = accept.encode('ascii') + + if accepts: + extensions = [] + for name, params in accepts.items(): + if params is True: + extensions.append(name.encode('ascii')) + else: + # py34 annoyance: doesn't support bytestring formatting + params = params.decode("ascii") + extensions.append(('%s; %s' % (name, params)) + .encode("ascii")) + return b', '.join(extensions) + + return None + + def accept(self, event, subprotocol=None): + request = event.h11request + request_headers = _normed_header_dict(request.headers) + + nonce = request_headers[b'sec-websocket-key'] + accept_token = self._generate_accept_token(nonce) + + headers = { + b"Upgrade": b'WebSocket', + b"Connection": b'Upgrade', 
+ b"Sec-WebSocket-Accept": accept_token, + } + + if subprotocol is not None: + if subprotocol not in event.proposed_subprotocols: + raise ValueError( + "unexpected subprotocol {!r}".format(subprotocol)) + headers[b'Sec-WebSocket-Protocol'] = subprotocol + + extensions = request_headers.get(b'sec-websocket-extensions', None) + if extensions: + accepts = self._extension_accept(extensions) + if accepts: + headers[b"Sec-WebSocket-Extensions"] = accepts + + response = h11.InformationalResponse(status_code=101, + headers=headers.items()) + self._outgoing += self._upgrade_connection.send(response) + self._proto = FrameProtocol(self.client, self.extensions) + self._state = ConnectionState.OPEN + + def ping(self, payload=None): + """ + Send a PING message to the peer. + + :param payload: an optional payload to send with the message + """ + + payload = bytes(payload or b'') + self._outgoing += self._proto.ping(payload) + + def pong(self, payload=None): + """ + Send a PONG message to the peer. + + This method can be used to send an unsolicted PONG to the peer. + It is not needed otherwise since every received PING causes a + corresponding PONG to be sent automatically. + + :param payload: an optional payload to send with the message + """ + + payload = bytes(payload or b'') + self._outgoing += self._proto.pong(payload) diff --git a/mitmproxy/contrib/wsproto/events.py b/mitmproxy/contrib/wsproto/events.py new file mode 100644 index 000000000..73ce27aac --- /dev/null +++ b/mitmproxy/contrib/wsproto/events.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +""" +wsproto/events +~~~~~~~~~~ + +Events that result from processing data on a WebSocket connection. 
+""" + + +class ConnectionRequested(object): + def __init__(self, proposed_subprotocols, h11request): + self.proposed_subprotocols = proposed_subprotocols + self.h11request = h11request + + def __repr__(self): + path = self.h11request.target + + headers = dict(self.h11request.headers) + host = headers[b'host'] + version = headers[b'sec-websocket-version'] + subprotocol = headers.get(b'sec-websocket-protocol', None) + extensions = [] + + fmt = '<%s host=%s path=%s version=%s subprotocol=%r extensions=%r>' + return fmt % (self.__class__.__name__, host, path, version, + subprotocol, extensions) + + +class ConnectionEstablished(object): + def __init__(self, subprotocol=None, extensions=None): + self.subprotocol = subprotocol + self.extensions = extensions + if self.extensions is None: + self.extensions = [] + + def __repr__(self): + return '' % \ + (self.subprotocol, self.extensions) + + +class ConnectionClosed(object): + def __init__(self, code, reason=None): + self.code = code + self.reason = reason + + def __repr__(self): + return '<%s code=%r reason="%s">' % (self.__class__.__name__, + self.code, self.reason) + + +class ConnectionFailed(ConnectionClosed): + pass + + +class DataReceived(object): + def __init__(self, data, frame_finished, message_finished): + self.data = data + # This has no semantic content, but is provided just in case some + # weird edge case user wants to be able to reconstruct the + # fragmentation pattern of the original stream. 
You don't want it: + self.frame_finished = frame_finished + # This is the field that you almost certainly want: + self.message_finished = message_finished + + +class TextReceived(DataReceived): + pass + + +class BytesReceived(DataReceived): + pass + + +class PingReceived(object): + def __init__(self, payload): + self.payload = payload + + +class PongReceived(object): + def __init__(self, payload): + self.payload = payload diff --git a/mitmproxy/contrib/wsproto/extensions.py b/mitmproxy/contrib/wsproto/extensions.py new file mode 100644 index 000000000..f7cf4fb61 --- /dev/null +++ b/mitmproxy/contrib/wsproto/extensions.py @@ -0,0 +1,257 @@ +# -*- coding: utf-8 -*- +""" +wsproto/extensions +~~~~~~~~~~~~~~ + +WebSocket extensions. +""" + +import zlib + +from .frame_protocol import CloseReason, Opcode, RsvBits + + +class Extension(object): + name = None + + def enabled(self): + return False + + def offer(self, connection): + pass + + def accept(self, connection, offer): + pass + + def finalize(self, connection, offer): + pass + + def frame_inbound_header(self, proto, opcode, rsv, payload_length): + return RsvBits(False, False, False) + + def frame_inbound_payload_data(self, proto, data): + return data + + def frame_inbound_complete(self, proto, fin): + pass + + def frame_outbound(self, proto, opcode, rsv, data, fin): + return (rsv, data) + + +class PerMessageDeflate(Extension): + name = 'permessage-deflate' + + DEFAULT_CLIENT_MAX_WINDOW_BITS = 15 + DEFAULT_SERVER_MAX_WINDOW_BITS = 15 + + def __init__(self, client_no_context_takeover=False, + client_max_window_bits=None, server_no_context_takeover=False, + server_max_window_bits=None): + self.client_no_context_takeover = client_no_context_takeover + if client_max_window_bits is None: + client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS + self.client_max_window_bits = client_max_window_bits + self.server_no_context_takeover = server_no_context_takeover + if server_max_window_bits is None: + 
server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS + self.server_max_window_bits = server_max_window_bits + + self._compressor = None + self._decompressor = None + # This refers to the current frame + self._inbound_is_compressible = None + # This refers to the ongoing message (which might span multiple + # frames). Only the first frame in a fragmented message is flagged for + # compression, so this carries that bit forward. + self._inbound_compressed = None + + self._enabled = False + + def _compressible_opcode(self, opcode): + return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION) + + def enabled(self): + return self._enabled + + def offer(self, connection): + parameters = [ + 'client_max_window_bits=%d' % self.client_max_window_bits, + 'server_max_window_bits=%d' % self.server_max_window_bits, + ] + + if self.client_no_context_takeover: + parameters.append('client_no_context_takeover') + if self.server_no_context_takeover: + parameters.append('server_no_context_takeover') + + return '; '.join(parameters) + + def finalize(self, connection, offer): + bits = [b.strip() for b in offer.split(';')] + for bit in bits[1:]: + if bit.startswith('client_no_context_takeover'): + self.client_no_context_takeover = True + elif bit.startswith('server_no_context_takeover'): + self.server_no_context_takeover = True + elif bit.startswith('client_max_window_bits'): + self.client_max_window_bits = int(bit.split('=', 1)[1].strip()) + elif bit.startswith('server_max_window_bits'): + self.server_max_window_bits = int(bit.split('=', 1)[1].strip()) + + self._enabled = True + + def _parse_params(self, params): + client_max_window_bits = None + server_max_window_bits = None + + bits = [b.strip() for b in params.split(';')] + for bit in bits[1:]: + if bit.startswith('client_no_context_takeover'): + self.client_no_context_takeover = True + elif bit.startswith('server_no_context_takeover'): + self.server_no_context_takeover = True + elif 
bit.startswith('client_max_window_bits'): + if '=' in bit: + client_max_window_bits = int(bit.split('=', 1)[1].strip()) + else: + client_max_window_bits = self.client_max_window_bits + elif bit.startswith('server_max_window_bits'): + if '=' in bit: + server_max_window_bits = int(bit.split('=', 1)[1].strip()) + else: + server_max_window_bits = self.server_max_window_bits + + return client_max_window_bits, server_max_window_bits + + def accept(self, connection, offer): + client_max_window_bits, server_max_window_bits = \ + self._parse_params(offer) + + self._enabled = True + + parameters = [] + + if self.client_no_context_takeover: + parameters.append('client_no_context_takeover') + if client_max_window_bits is not None: + parameters.append('client_max_window_bits=%d' % + client_max_window_bits) + self.client_max_window_bits = client_max_window_bits + if self.server_no_context_takeover: + parameters.append('server_no_context_takeover') + if server_max_window_bits is not None: + parameters.append('server_max_window_bits=%d' % + server_max_window_bits) + self.server_max_window_bits = server_max_window_bits + + return '; '.join(parameters) + + def frame_inbound_header(self, proto, opcode, rsv, payload_length): + if rsv.rsv1 and opcode.iscontrol(): + return CloseReason.PROTOCOL_ERROR + elif rsv.rsv1 and opcode is Opcode.CONTINUATION: + return CloseReason.PROTOCOL_ERROR + + self._inbound_is_compressible = self._compressible_opcode(opcode) + + if self._inbound_compressed is None: + self._inbound_compressed = rsv.rsv1 + if self._inbound_compressed: + assert self._inbound_is_compressible + if proto.client: + bits = self.server_max_window_bits + else: + bits = self.client_max_window_bits + if self._decompressor is None: + self._decompressor = zlib.decompressobj(-int(bits)) + + return RsvBits(True, False, False) + + def frame_inbound_payload_data(self, proto, data): + if not self._inbound_compressed or not self._inbound_is_compressible: + return data + + try: + return 
self._decompressor.decompress(bytes(data)) + except zlib.error: + return CloseReason.INVALID_FRAME_PAYLOAD_DATA + + def frame_inbound_complete(self, proto, fin): + if not fin: + return + elif not self._inbound_is_compressible: + return + elif not self._inbound_compressed: + return + + try: + data = self._decompressor.decompress(b'\x00\x00\xff\xff') + data += self._decompressor.flush() + except zlib.error: + return CloseReason.INVALID_FRAME_PAYLOAD_DATA + + if proto.client: + no_context_takeover = self.server_no_context_takeover + else: + no_context_takeover = self.client_no_context_takeover + + if no_context_takeover: + self._decompressor = None + + self._inbound_compressed = None + + return data + + def frame_outbound(self, proto, opcode, rsv, data, fin): + if not self._compressible_opcode(opcode): + return (rsv, data) + + if opcode is not Opcode.CONTINUATION: + rsv = RsvBits(True, *rsv[1:]) + + if self._compressor is None: + assert opcode is not Opcode.CONTINUATION + if proto.client: + bits = self.client_max_window_bits + else: + bits = self.server_max_window_bits + self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, + zlib.DEFLATED, -int(bits)) + + data = self._compressor.compress(bytes(data)) + + if fin: + data += self._compressor.flush(zlib.Z_SYNC_FLUSH) + data = data[:-4] + + if proto.client: + no_context_takeover = self.client_no_context_takeover + else: + no_context_takeover = self.server_no_context_takeover + + if no_context_takeover: + self._compressor = None + + return (rsv, data) + + def __repr__(self): + descr = ['client_max_window_bits=%d' % self.client_max_window_bits] + if self.client_no_context_takeover: + descr.append('client_no_context_takeover') + descr.append('server_max_window_bits=%d' % self.server_max_window_bits) + if self.server_no_context_takeover: + descr.append('server_no_context_takeover') + + descr = '; '.join(descr) + + return '<%s %s>' % (self.__class__.__name__, descr) + + +#: SUPPORTED_EXTENSIONS maps all supported 
extension names to their class. +#: This can be used to iterate all supported extensions of wsproto, instantiate +#: new extensions based on their name, or check if a given extension is +#: supported or not. +SUPPORTED_EXTENSIONS = { + PerMessageDeflate.name: PerMessageDeflate +} diff --git a/mitmproxy/contrib/wsproto/frame_protocol.py b/mitmproxy/contrib/wsproto/frame_protocol.py new file mode 100644 index 000000000..b95dceec2 --- /dev/null +++ b/mitmproxy/contrib/wsproto/frame_protocol.py @@ -0,0 +1,579 @@ +# -*- coding: utf-8 -*- +""" +wsproto/frame_protocol +~~~~~~~~~~~~~~ + +WebSocket frame protocol implementation. +""" + +import os +import itertools +import struct +from codecs import getincrementaldecoder +from collections import namedtuple + +from enum import Enum, IntEnum + +from .compat import unicode, Utf8Validator + +try: + from wsaccel.xormask import XorMaskerSimple +except ImportError: + class XorMaskerSimple: + def __init__(self, masking_key): + self._maskbytes = itertools.cycle(bytearray(masking_key)) + + def process(self, data): + maskbytes = self._maskbytes + return bytearray(b ^ next(maskbytes) for b in bytearray(data)) + + +class XorMaskerNull: + def process(self, data): + return data + + +# RFC6455, Section 5.2 - Base Framing Protocol + +# Payload length constants +PAYLOAD_LENGTH_TWO_BYTE = 126 +PAYLOAD_LENGTH_EIGHT_BYTE = 127 +MAX_PAYLOAD_NORMAL = 125 +MAX_PAYLOAD_TWO_BYTE = 2 ** 16 - 1 +MAX_PAYLOAD_EIGHT_BYTE = 2 ** 64 - 1 +MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE + +# MASK and PAYLOAD LEN are packed into a byte +MASK_MASK = 0x80 +PAYLOAD_LEN_MASK = 0x7f + +# FIN, RSV[123] and OPCODE are packed into a single byte +FIN_MASK = 0x80 +RSV1_MASK = 0x40 +RSV2_MASK = 0x20 +RSV3_MASK = 0x10 +OPCODE_MASK = 0x0f + + +class Opcode(IntEnum): + """ + RFC 6455, Section 5.2 - Base Framing Protocol + """ + CONTINUATION = 0x0 + TEXT = 0x1 + BINARY = 0x2 + CLOSE = 0x8 + PING = 0x9 + PONG = 0xA + + def iscontrol(self): + return bool(self & 0x08) + + +class 
CloseReason(IntEnum): + """ + RFC 6455, Section 7.4.1 - Defined Status Codes + """ + NORMAL_CLOSURE = 1000 + GOING_AWAY = 1001 + PROTOCOL_ERROR = 1002 + UNSUPPORTED_DATA = 1003 + NO_STATUS_RCVD = 1005 + ABNORMAL_CLOSURE = 1006 + INVALID_FRAME_PAYLOAD_DATA = 1007 + POLICY_VIOLATION = 1008 + MESSAGE_TOO_BIG = 1009 + MANDATORY_EXT = 1010 + INTERNAL_ERROR = 1011 + SERVICE_RESTART = 1012 + TRY_AGAIN_LATER = 1013 + TLS_HANDSHAKE_FAILED = 1015 + + +# RFC 6455, Section 7.4.1 - Defined Status Codes +LOCAL_ONLY_CLOSE_REASONS = ( + CloseReason.NO_STATUS_RCVD, + CloseReason.ABNORMAL_CLOSURE, + CloseReason.TLS_HANDSHAKE_FAILED, +) + + +# RFC 6455, Section 7.4.2 - Status Code Ranges +MIN_CLOSE_REASON = 1000 +MIN_PROTOCOL_CLOSE_REASON = 1000 +MAX_PROTOCOL_CLOSE_REASON = 2999 +MIN_LIBRARY_CLOSE_REASON = 3000 +MAX_LIBRARY_CLOSE_REASON = 3999 +MIN_PRIVATE_CLOSE_REASON = 4000 +MAX_PRIVATE_CLOSE_REASON = 4999 +MAX_CLOSE_REASON = 4999 + + +NULL_MASK = struct.pack("!I", 0) + + +class ParseFailed(Exception): + def __init__(self, msg, code=CloseReason.PROTOCOL_ERROR): + super(ParseFailed, self).__init__(msg) + self.code = code + + +Header = namedtuple("Header", "fin rsv opcode payload_len masking_key".split()) + + +Frame = namedtuple("Frame", + "opcode payload frame_finished message_finished".split()) + + +RsvBits = namedtuple("RsvBits", "rsv1 rsv2 rsv3".split()) + + +def _truncate_utf8(data, nbytes): + if len(data) <= nbytes: + return data + + # Truncate + data = data[:nbytes] + # But we might have cut a codepoint in half, in which case we want to + # discard the partial character so the data is at least + # well-formed. This is a little inefficient since it processes the + # whole message twice when in theory we could just peek at the last + # few characters, but since this is only used for close messages (max + # length = 125 bytes) it really doesn't matter. 
+ data = data.decode("utf-8", errors="ignore").encode("utf-8") + return data + + +class Buffer(object): + def __init__(self, initial_bytes=None): + self.buffer = bytearray() + self.bytes_used = 0 + if initial_bytes: + self.feed(initial_bytes) + + def feed(self, new_bytes): + self.buffer += new_bytes + + def consume_at_most(self, nbytes): + if not nbytes: + return bytearray() + + data = self.buffer[self.bytes_used:self.bytes_used + nbytes] + self.bytes_used += len(data) + return data + + def consume_exactly(self, nbytes): + if len(self.buffer) - self.bytes_used < nbytes: + return None + + return self.consume_at_most(nbytes) + + def commit(self): + # In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic + del self.buffer[:self.bytes_used] + self.bytes_used = 0 + + def rollback(self): + self.bytes_used = 0 + + def __len__(self): + return len(self.buffer) + + +class MessageDecoder(object): + def __init__(self): + self.opcode = None + self.validator = None + self.decoder = None + + def process_frame(self, frame): + assert not frame.opcode.iscontrol() + + if self.opcode is None: + if frame.opcode is Opcode.CONTINUATION: + raise ParseFailed("unexpected CONTINUATION") + self.opcode = frame.opcode + elif frame.opcode is not Opcode.CONTINUATION: + raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode) + + if frame.opcode is Opcode.TEXT: + self.validator = Utf8Validator() + self.decoder = getincrementaldecoder("utf-8")() + + finished = frame.frame_finished and frame.message_finished + + if self.decoder is not None: + data = self.decode_payload(frame.payload, finished) + else: + data = frame.payload + + frame = Frame(self.opcode, data, frame.frame_finished, finished) + + if finished: + self.opcode = None + self.decoder = None + + return frame + + def decode_payload(self, data, finished): + if self.validator is not None: + results = self.validator.validate(bytes(data)) + if not results[0] or (finished and not results[1]): + raise ParseFailed(u'encountered 
invalid UTF-8 while processing' + ' text message at payload octet index %d' % + results[3], + CloseReason.INVALID_FRAME_PAYLOAD_DATA) + + try: + return self.decoder.decode(data, finished) + except UnicodeDecodeError as exc: + raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA) + + +class FrameDecoder(object): + def __init__(self, client, extensions=None): + self.client = client + self.extensions = extensions or [] + + self.buffer = Buffer() + + self.header = None + self.effective_opcode = None + self.masker = None + self.payload_required = 0 + self.payload_consumed = 0 + + def receive_bytes(self, data): + self.buffer.feed(data) + + def process_buffer(self): + if not self.header: + if not self.parse_header(): + return None + + if len(self.buffer) < self.payload_required: + return None + + payload_remaining = self.header.payload_len - self.payload_consumed + payload = self.buffer.consume_at_most(payload_remaining) + if not payload and self.header.payload_len > 0: + return None + self.buffer.commit() + + self.payload_consumed += len(payload) + finished = self.payload_consumed == self.header.payload_len + + payload = self.masker.process(payload) + + for extension in self.extensions: + payload = extension.frame_inbound_payload_data(self, payload) + if isinstance(payload, CloseReason): + raise ParseFailed("error in extension", payload) + + if finished: + final = bytearray() + for extension in self.extensions: + result = extension.frame_inbound_complete(self, + self.header.fin) + if isinstance(result, CloseReason): + raise ParseFailed("error in extension", result) + if result is not None: + final += result + payload += final + + frame = Frame(self.effective_opcode, payload, finished, + self.header.fin) + + if finished: + self.header = None + self.effective_opcode = None + self.masker = None + else: + self.effective_opcode = Opcode.CONTINUATION + + return frame + + def parse_header(self): + data = self.buffer.consume_exactly(2) + if data is None: + 
self.buffer.rollback() + return False + + fin = bool(data[0] & FIN_MASK) + rsv = RsvBits(bool(data[0] & RSV1_MASK), + bool(data[0] & RSV2_MASK), + bool(data[0] & RSV3_MASK)) + opcode = data[0] & OPCODE_MASK + try: + opcode = Opcode(opcode) + except ValueError: + raise ParseFailed("Invalid opcode {:#x}".format(opcode)) + + if opcode.iscontrol() and not fin: + raise ParseFailed("Invalid attempt to fragment control frame") + + has_mask = bool(data[1] & MASK_MASK) + payload_len = data[1] & PAYLOAD_LEN_MASK + payload_len = self.parse_extended_payload_length(opcode, payload_len) + if payload_len is None: + self.buffer.rollback() + return False + + self.extension_processing(opcode, rsv, payload_len) + + if has_mask and self.client: + raise ParseFailed("client received unexpected masked frame") + if not has_mask and not self.client: + raise ParseFailed("server received unexpected unmasked frame") + if has_mask: + masking_key = self.buffer.consume_exactly(4) + if masking_key is None: + self.buffer.rollback() + return False + self.masker = XorMaskerSimple(masking_key) + else: + self.masker = XorMaskerNull() + + self.buffer.commit() + self.header = Header(fin, rsv, opcode, payload_len, None) + self.effective_opcode = self.header.opcode + if self.header.opcode.iscontrol(): + self.payload_required = payload_len + else: + self.payload_required = 0 + self.payload_consumed = 0 + return True + + def parse_extended_payload_length(self, opcode, payload_len): + if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL: + raise ParseFailed("Control frame with payload len > 125") + if payload_len == PAYLOAD_LENGTH_TWO_BYTE: + data = self.buffer.consume_exactly(2) + if data is None: + return None + (payload_len,) = struct.unpack("!H", data) + if payload_len <= MAX_PAYLOAD_NORMAL: + raise ParseFailed( + "Payload length used 2 bytes when 1 would have sufficed") + elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE: + data = self.buffer.consume_exactly(8) + if data is None: + return None + 
(payload_len,) = struct.unpack("!Q", data) + if payload_len <= MAX_PAYLOAD_TWO_BYTE: + raise ParseFailed( + "Payload length used 8 bytes when 2 would have sufficed") + if payload_len >> 63: + # I'm not sure why this is illegal, but that's what the RFC + # says, so... + raise ParseFailed("8-byte payload length with non-zero MSB") + + return payload_len + + def extension_processing(self, opcode, rsv, payload_len): + rsv_used = [False, False, False] + for extension in self.extensions: + result = extension.frame_inbound_header(self, opcode, rsv, + payload_len) + if isinstance(result, CloseReason): + raise ParseFailed("error in extension", result) + for bit, used in enumerate(result): + if used: + rsv_used[bit] = True + for expected, found in zip(rsv_used, rsv): + if found and not expected: + raise ParseFailed("Reserved bit set unexpectedly") + + +class FrameProtocol(object): + class State(Enum): + HEADER = 1 + PAYLOAD = 2 + FRAME_COMPLETE = 3 + FAILED = 4 + + def __init__(self, client, extensions): + self.client = client + self.extensions = [ext for ext in extensions if ext.enabled()] + + # Global state + self._frame_decoder = FrameDecoder(self.client, self.extensions) + self._message_decoder = MessageDecoder() + self._parse_more = self.parse_more_gen() + + self._outbound_opcode = None + + def _process_close(self, frame): + data = frame.payload + + if not data: + # "If this Close control frame contains no status code, _The + # WebSocket Connection Close Code_ is considered to be 1005" + data = (CloseReason.NO_STATUS_RCVD, "") + elif len(data) == 1: + raise ParseFailed("CLOSE with 1 byte payload") + else: + (code,) = struct.unpack("!H", data[:2]) + if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON: + raise ParseFailed("CLOSE with invalid code") + try: + code = CloseReason(code) + except ValueError: + pass + if code in LOCAL_ONLY_CLOSE_REASONS: + raise ParseFailed( + "remote CLOSE with local-only reason") + if not isinstance(code, CloseReason) and \ + code <= 
MAX_PROTOCOL_CLOSE_REASON: + raise ParseFailed( + "CLOSE with unknown reserved code") + validator = Utf8Validator() + if validator is not None: + results = validator.validate(bytes(data[2:])) + if not (results[0] and results[1]): + raise ParseFailed(u'encountered invalid UTF-8 while' + ' processing close message at payload' + ' octet index %d' % + results[3], + CloseReason.INVALID_FRAME_PAYLOAD_DATA) + try: + reason = data[2:].decode("utf-8") + except UnicodeDecodeError as exc: + raise ParseFailed( + "Error decoding CLOSE reason: " + str(exc), + CloseReason.INVALID_FRAME_PAYLOAD_DATA) + data = (code, reason) + + return Frame(frame.opcode, data, frame.frame_finished, + frame.message_finished) + + def parse_more_gen(self): + # Consume as much as we can from self._buffer, yielding events, and + # then yield None when we need more data. Or raise ParseFailed. + + # XX FIXME this should probably be refactored so that we never see + # disabled extensions in the first place... + self.extensions = [ext for ext in self.extensions if ext.enabled()] + closed = False + + while not closed: + frame = self._frame_decoder.process_buffer() + + if frame is not None: + if not frame.opcode.iscontrol(): + frame = self._message_decoder.process_frame(frame) + elif frame.opcode == Opcode.CLOSE: + frame = self._process_close(frame) + closed = True + + yield frame + + def receive_bytes(self, data): + self._frame_decoder.receive_bytes(data) + + def received_frames(self): + for event in self._parse_more: + if event is None: + break + else: + yield event + + def close(self, code=None, reason=None): + payload = bytearray() + if code is None and reason is not None: + raise TypeError("cannot specify a reason without a code") + if code in LOCAL_ONLY_CLOSE_REASONS: + code = CloseReason.NORMAL_CLOSURE + if code is not None: + payload += bytearray(struct.pack('!H', code)) + if reason is not None: + payload += _truncate_utf8(reason.encode('utf-8'), + MAX_PAYLOAD_NORMAL - 2) + + return 
self._serialize_frame(Opcode.CLOSE, payload) + + def ping(self, payload=b''): + return self._serialize_frame(Opcode.PING, payload) + + def pong(self, payload=b''): + return self._serialize_frame(Opcode.PONG, payload) + + def send_data(self, payload=b'', fin=True): + if isinstance(payload, (bytes, bytearray, memoryview)): + opcode = Opcode.BINARY + elif isinstance(payload, unicode): + opcode = Opcode.TEXT + payload = payload.encode('utf-8') + else: + raise ValueError('Must provide bytes or text') + + if self._outbound_opcode is None: + self._outbound_opcode = opcode + elif self._outbound_opcode is not opcode: + raise TypeError('Data type mismatch inside message') + else: + opcode = Opcode.CONTINUATION + + if fin: + self._outbound_opcode = None + + return self._serialize_frame(opcode, payload, fin) + + def _make_fin_rsv_opcode(self, fin, rsv, opcode): + fin = int(fin) << 7 + rsv = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + \ + (int(rsv.rsv3) << 4) + opcode = int(opcode) + + return fin | rsv | opcode + + def _serialize_frame(self, opcode, payload=b'', fin=True): + rsv = RsvBits(False, False, False) + for extension in reversed(self.extensions): + rsv, payload = extension.frame_outbound(self, opcode, rsv, payload, + fin) + + fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode) + + payload_length = len(payload) + quad_payload = False + if payload_length <= MAX_PAYLOAD_NORMAL: + first_payload = payload_length + second_payload = None + elif payload_length <= MAX_PAYLOAD_TWO_BYTE: + first_payload = PAYLOAD_LENGTH_TWO_BYTE + second_payload = payload_length + else: + first_payload = PAYLOAD_LENGTH_EIGHT_BYTE + second_payload = payload_length + quad_payload = True + + if self.client: + first_payload |= 1 << 7 + + header = bytearray([fin_rsv_opcode, first_payload]) + if second_payload is not None: + if opcode.iscontrol(): + raise ValueError("payload too long for control frame") + if quad_payload: + header += bytearray(struct.pack('!Q', second_payload)) + else: + 
header += bytearray(struct.pack('!H', second_payload)) + + if self.client: + # "The masking key is a 32-bit value chosen at random by the + # client. When preparing a masked frame, the client MUST pick a + # fresh masking key from the set of allowed 32-bit values. The + # masking key needs to be unpredictable; thus, the masking key + # MUST be derived from a strong source of entropy, and the masking + # key for a given frame MUST NOT make it simple for a server/proxy + # to predict the masking key for a subsequent frame. The + # unpredictability of the masking key is essential to prevent + # authors of malicious applications from selecting the bytes that + # appear on the wire." + # -- https://tools.ietf.org/html/rfc6455#section-5.3 + masking_key = os.urandom(4) + masker = XorMaskerSimple(masking_key) + return header + masking_key + masker.process(payload) + + return header + payload diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 54d8120de..34dcba066 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -1,10 +1,10 @@ import socket from OpenSSL import SSL -from wsproto import events -from wsproto.connection import ConnectionType, WSConnection -from wsproto.extensions import PerMessageDeflate -from wsproto.frame_protocol import Opcode +from mitmproxy.contrib.wsproto import events +from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection +from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate +from mitmproxy.contrib.wsproto.frame_protocol import Opcode from mitmproxy import exceptions from mitmproxy import flow diff --git a/setup.py b/setup.py index 54c2811d9..ad792881e 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ setup( "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! "click>=6.2, <7", "cryptography>=2.0,<2.2", + 'h11>=0.7.0,<0.8', "h2>=3.0, <4", "hyperframe>=5.0, <6", "kaitaistruct>=0.7, <0.8",