diff --git a/mitmproxy/addons/streambodies.py b/mitmproxy/addons/streambodies.py index 16c5978d1..b98ed1fa3 100644 --- a/mitmproxy/addons/streambodies.py +++ b/mitmproxy/addons/streambodies.py @@ -2,6 +2,7 @@ from mitmproxy.net.http import http1 from mitmproxy import exceptions from mitmproxy import ctx from mitmproxy.utils import human +from mitmproxy import websocket class StreamBodies: @@ -17,6 +18,13 @@ class StreamBodies: def run(self, f, is_request): if self.max_size: + if isinstance(f, websocket.WebSocketFlow): + f.stream = True + ctx.log.info("Streaming WebSocket message {client} - {server}".format( + client=human.format_address(f.client_conn.address), + server=human.format_address(f.server_conn.address)) + ) + return r = f.request if is_request else f.response try: expected_size = http1.expected_http_body_size( @@ -30,9 +38,11 @@ class StreamBodies: r.stream = r.stream or True ctx.log.info("Streaming {} {}".format("response from" if not is_request else "request to", f.request.host)) - # FIXME! Request streaming doesn't work at the moment. def requestheaders(self, f): self.run(f, True) def responseheaders(self, f): self.run(f, False) + + def websocket_start(self, f): + self.run(f, False) diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 373c6479d..19546eb2e 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -55,6 +55,7 @@ class WebSocketLayer(base.Layer): return self._handle_unknown_frame(frame, source_conn, other_conn, is_server) def _handle_data_frame(self, frame, source_conn, other_conn, is_server): + fb = self.server_frame_buffer if is_server else self.client_frame_buffer fb.append(frame) @@ -70,43 +71,51 @@ class WebSocketLayer(base.Layer): self.flow.messages.append(websocket_message) self.channel.ask("websocket_message", self.flow) - def get_chunk(payload): - if len(payload) == length: - # message has the same length, we can reuse the same sizes - pos = 0 - for s in original_chunk_sizes: - yield payload[pos:pos + s] - pos += s + if not self.flow.stream: + def get_chunk(payload): + if len(payload) == length: + # message has the same length, we can reuse the same sizes + pos = 0 + for s in original_chunk_sizes: + yield payload[pos:pos + s] + pos += s + else: + # just re-chunk everything into 4kB frames + # header len = 4 bytes without masking key and 8 bytes with masking key + chunk_size = 4092 if is_server else 4088 + chunks = range(0, len(payload), chunk_size) + for i in chunks: + yield payload[i:i + chunk_size] + + frms = [ + websockets.Frame( + payload=chunk, + opcode=frame.header.opcode, + mask=(False if is_server else 1), + masking_key=(b'' if is_server else os.urandom(4))) + for chunk in get_chunk(websocket_message.content) + ] + + if len(frms) > 0: + frms[-1].header.fin = True else: - # just re-chunk everything into 10kB frames - chunk_size = 10240 - chunks = range(0, len(payload), chunk_size) - for i in chunks: - yield payload[i:i + chunk_size] + frms.append(websockets.Frame( + fin=True, + opcode=websockets.OPCODE.CONTINUE, + mask=(False if is_server else 1), + masking_key=(b'' if is_server else os.urandom(4)))) - frms = [ - websockets.Frame( - payload=chunk, - opcode=frame.header.opcode, - mask=(False if is_server else 1), - masking_key=(b'' if is_server else os.urandom(4))) - for chunk in get_chunk(websocket_message.content) - ] + frms[0].header.opcode = message_type + frms[0].header.rsv1 = compressed_message + + for frm in frms: + other_conn.send(bytes(frm)) - if len(frms) > 0: - frms[-1].header.fin = True else: - frms.append(websockets.Frame( - fin=True, - opcode=websockets.OPCODE.CONTINUE, - mask=(False if is_server else 1), - masking_key=(b'' if is_server else os.urandom(4)))) + other_conn.send(bytes(frame)) - frms[0].header.opcode = message_type - frms[0].header.rsv1 = compressed_message - - for frm in frms: - other_conn.send(bytes(frm)) + elif self.flow.stream: + other_conn.send(bytes(frame)) return True diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 30967a91f..ded09f655 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -45,6 +45,7 @@ class WebSocketFlow(flow.Flow): self.close_code = '(status code missing)' self.close_message = '(message missing)' self.close_reason = 'unknown status code' + self.stream = False if handshake_flow: self.client_key = websockets.get_client_key(handshake_flow.request.headers) diff --git a/test/mitmproxy/addons/test_streambodies.py b/test/mitmproxy/addons/test_streambodies.py index c6ce5e81c..547999499 100644 --- a/test/mitmproxy/addons/test_streambodies.py +++ b/test/mitmproxy/addons/test_streambodies.py @@ -29,3 +29,8 @@ def test_simple(): f = tflow.tflow(resp=True) f.response.headers["content-length"] = "invalid" tctx.cycle(sa, f) + + f = tflow.twebsocketflow() + assert not f.stream + sa.websocket_start(f) + assert f.stream diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index f78e173fc..58857f920 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -155,7 +155,13 @@ class TestSimple(_WebSocketTest): wfile.write(bytes(frame)) wfile.flush() - def test_simple(self): + @pytest.mark.parametrize('streaming', [True, False]) + def test_simple(self, streaming): + class Stream: + def websocket_start(self, f): + f.stream = streaming + + self.master.addons.add(Stream()) self.setup_connection() frame = websockets.Frame.from_file(self.client.rfile) @@ -328,3 +334,32 @@ class TestInvalidFrame(_WebSocketTest): frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == 15 assert frame.payload == b'foobar' + + +class TestStreaming(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + @pytest.mark.parametrize('streaming', [True, False]) + def test_streaming(self, streaming): + class Stream: + def websocket_start(self, f): + f.stream = streaming + + self.master.addons.add(Stream()) + self.setup_connection() + + frame = None + if not streaming: + with pytest.raises(exceptions.TcpDisconnect): # Reader.safe_read get nothing as result + frame = websockets.Frame.from_file(self.client.rfile) + assert frame is None + + else: + frame = websockets.Frame.from_file(self.client.rfile) + + assert frame + assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received