diff --git a/mitmproxy/net/websocket_utils.py b/mitmproxy/net/websocket_utils.py index 84e35fa27..f8608b07b 100644 --- a/mitmproxy/net/websocket_utils.py +++ b/mitmproxy/net/websocket_utils.py @@ -7,12 +7,70 @@ Spec: https://tools.ietf.org/html/rfc6455 import base64 import hashlib import os +import struct from wsproto.utilities import ACCEPT_GUID from wsproto.handshake import WEBSOCKET_VERSION +from wsproto.frame_protocol import RsvBits, Header, Frame, XorMaskerSimple, XorMaskerNull from mitmproxy.net import http -from mitmproxy.utils import strutils +from mitmproxy.utils import bits, strutils + + +def read_raw_frame(rfile): + consumed_bytes = b'' + + def consume(len): + nonlocal consumed_bytes + d = rfile.safe_read(len) + consumed_bytes += d + return d + + first_byte, second_byte = consume(2) + fin = bits.getbit(first_byte, 7) + rsv1 = bits.getbit(first_byte, 6) + rsv2 = bits.getbit(first_byte, 5) + rsv3 = bits.getbit(first_byte, 4) + opcode = first_byte & 0xF + mask_bit = bits.getbit(second_byte, 7) + length_code = second_byte & 0x7F + + # payload_len > 125 indicates you need to read more bytes + # to get the actual payload length + if length_code <= 125: + payload_len = length_code + elif length_code == 126: + payload_len, = struct.unpack("!H", consume(2)) + else: # length_code == 127: + payload_len, = struct.unpack("!Q", consume(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = consume(4) + masker = XorMaskerSimple(masking_key) + else: + masking_key = None + masker = XorMaskerNull() + + header = Header( + fin=fin, + rsv=RsvBits(rsv1, rsv2, rsv3), + opcode=opcode, + payload_len=payload_len, + masking_key=masking_key, + ) + + masked_payload = consume(payload_len) + payload = masker.process(masked_payload) + + frame = Frame( + opcode=opcode, + payload=payload, + frame_finished=fin, + message_finished=fin + ) + + return header, frame, consumed_bytes def client_handshake_headers(version=None, key=None, protocol=None, extensions=None): diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 0134427b0..e7697169c 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -8,15 +8,12 @@ from wsproto.connection import ConnectionType from wsproto.events import AcceptConnection, CloseConnection, Message, Ping, Request from wsproto.extensions import PerMessageDeflate -from mitmproxy import exceptions -from mitmproxy import flow +from mitmproxy import exceptions, flow from mitmproxy.proxy.protocol import base -from mitmproxy.net import tcp +from mitmproxy.net import tcp, websocket_utils from mitmproxy.websocket import WebSocketFlow, WebSocketMessage from mitmproxy.utils import strutils -from pathod.language import websockets_frame - class WebSocketLayer(base.Layer): """ @@ -79,6 +76,10 @@ class WebSocketLayer(base.Layer): assert isinstance(next(self.connections[self.server_conn].events()), events.AcceptConnection) def _handle_event(self, event, source_conn, other_conn, is_server): + self.log( + "WebSocket Event from {}: {}".format("server" if is_server else "client", event), + "debug" + ) if isinstance(event, events.Message): return self._handle_message(event, source_conn, other_conn, is_server) elif isinstance(event, events.Ping): @@ -199,9 +200,17 @@ class WebSocketLayer(base.Layer): other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn is_server = (source_conn == self.server_conn) - # TODO: replace this method from pathod with a stack-agnostic version - frame = websockets_frame.Frame.from_file(source_conn.rfile) - data = self.connections[source_conn].receive_data(bytes(frame)) + header, frame, consumed_bytes = websocket_utils.read_raw_frame(source_conn.rfile) + self.log( + "WebSocket Frame from {}: {}, {}".format( + "server" if is_server else "client", + header, + frame, + ), + "debug" + ) + + data = self.connections[source_conn].receive_data(consumed_bytes) source_conn.send(data) if close_received: diff --git a/test/mitmproxy/net/test_websocket_utils.py b/test/mitmproxy/net/test_websocket_utils.py index 48072b5e0..3c0dbbe4a 100644 --- a/test/mitmproxy/net/test_websocket_utils.py +++ b/test/mitmproxy/net/test_websocket_utils.py @@ -1,9 +1,42 @@ +import pytest +from io import BytesIO from unittest import mock +from wsproto.frame_protocol import Opcode, RsvBits, Header, Frame + from mitmproxy.net.http import Headers from mitmproxy.net import websocket_utils +@pytest.mark.parametrize("input,masking_key,payload_length", [ + (b'\x01\rserver-foobar', None, 13), + (b'\x01\x8dasdf\x12\x16\x16\x10\x04\x01I\x00\x0e\x1c\x06\x07\x13', b'asdf', 13), + (b'\x01~\x04\x00server-foobar', None, 1024), + (b'\x01\x7f\x00\x00\x00\x00\x00\x02\x00\x00server-foobar', None, 131072), +]) +def test_read_raw_frame(input, masking_key, payload_length): + bio = BytesIO(input) + bio.safe_read = bio.read + + header, frame, consumed_bytes = websocket_utils.read_raw_frame(bio) + assert header == \ + Header( + fin=False, + rsv=RsvBits(rsv1=False, rsv2=False, rsv3=False), + opcode=Opcode.TEXT, + payload_len=payload_length, + masking_key=masking_key, + ) + assert frame == \ + Frame( + opcode=Opcode.TEXT, + payload=b'server-foobar', + frame_finished=False, + message_finished=False, + ) + assert consumed_bytes == input + + @mock.patch('os.urandom', return_value=b'pumpkinspumpkins') def test_client_handshake_headers(_): assert websocket_utils.client_handshake_headers() == \ diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index c52c292b6..7e28cdfb0 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -146,12 +146,12 @@ class TestSimple(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.TEXT, payload=b'server-foobar'))) wfile.flush() - frame = websockets_frame.Frame.from_file(rfile) - wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) + header, frame, _ = websocket_utils.read_raw_frame(rfile) + wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload))) wfile.flush() - frame = websockets_frame.Frame.from_file(rfile) - wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) + header, frame, _ = websocket_utils.read_raw_frame(rfile) + wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload))) wfile.flush() @pytest.mark.parametrize('streaming', [True, False]) @@ -163,19 +163,19 @@ class TestSimple(_WebSocketTest): self.proxy.set_addons(Stream()) self.setup_connection() - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame.payload == b'server-foobar' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame.payload == b'self.client-foobar' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.BINARY, payload=b'\xde\xad\xbe\xef'))) self.client.wfile.flush() - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame.payload == b'\xde\xad\xbe\xef' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) @@ -204,19 +204,19 @@ class TestSimple(_WebSocketTest): self.proxy.set_addons(Addon()) self.setup_connection() - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame.payload == b'foo' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame.payload == b'foo' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.BINARY, payload=b'\xde\xad\xbe\xef'))) self.client.wfile.flush() - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame.payload == b'foo' @@ -236,7 +236,7 @@ class TestKillFlow(_WebSocketTest): self.setup_connection() with pytest.raises(exceptions.TcpDisconnect): - websockets_frame.Frame.from_file(self.client.rfile) + _, _, _ = websocket_utils.read_raw_frame(self.client.rfile) class TestSimpleTLS(_WebSocketTest): @@ -247,20 +247,20 @@ class TestSimpleTLS(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.TEXT, payload=b'server-foobar'))) wfile.flush() - frame = websockets_frame.Frame.from_file(rfile) - wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) + header, frame, _ = websocket_utils.read_raw_frame(rfile) + wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload))) wfile.flush() def test_simple_tls(self): self.setup_connection() - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame.payload == b'server-foobar' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.TEXT, payload=b'self.client-foobar'))) self.client.wfile.flush() - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame.payload == b'self.client-foobar' self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) @@ -274,8 +274,8 @@ class TestPing(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.PING, payload=b'foobar'))) wfile.flush() - frame = websockets_frame.Frame.from_file(rfile) - assert frame.header.opcode == Opcode.PONG + header, frame, _ = websocket_utils.read_raw_frame(rfile) + assert header.opcode == Opcode.PONG assert frame.payload == b'foobar' wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.PONG, payload=b'done'))) @@ -283,17 +283,17 @@ class TestPing(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.CLOSE))) wfile.flush() - websockets_frame.Frame.from_file(rfile) + _, _, _ = websocket_utils.read_raw_frame(rfile) @pytest.mark.asyncio async def test_ping(self): self.setup_connection() - frame = websockets_frame.Frame.from_file(self.client.rfile) - websockets_frame.Frame.from_file(self.client.rfile) + header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _ = websocket_utils.read_raw_frame(self.client.rfile) self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) self.client.wfile.flush() - assert frame.header.opcode == Opcode.PING + assert header.opcode == Opcode.PING assert frame.payload == b'' # We don't send payload to other end assert await self.master.await_log("Pong Received from server", "info") @@ -303,8 +303,8 @@ class TestPong(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): - frame = websockets_frame.Frame.from_file(rfile) - assert frame.header.opcode == Opcode.PING + header, frame, _ = websocket_utils.read_raw_frame(rfile) + assert header.opcode == Opcode.PING assert frame.payload == b'' wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.PONG, payload=frame.payload))) @@ -312,7 +312,7 @@ class TestPong(_WebSocketTest): wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.CLOSE))) wfile.flush() - websockets_frame.Frame.from_file(rfile) + _ = websocket_utils.read_raw_frame(rfile) @pytest.mark.asyncio async def test_pong(self): @@ -321,12 +321,12 @@ class TestPong(_WebSocketTest): self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.PING, payload=b'foobar'))) self.client.wfile.flush() - frame = websockets_frame.Frame.from_file(self.client.rfile) - websockets_frame.Frame.from_file(self.client.rfile) + header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + _ = websocket_utils.read_raw_frame(self.client.rfile) self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) self.client.wfile.flush() - assert frame.header.opcode == Opcode.PONG + assert header.opcode == Opcode.PONG assert frame.payload == b'foobar' assert await self.master.await_log("pong received") @@ -335,13 +335,13 @@ class TestClose(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): - frame = websockets_frame.Frame.from_file(rfile) - wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) + header, frame, _ = websocket_utils.read_raw_frame(rfile) + wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=header.opcode, payload=frame.payload))) wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.CLOSE))) wfile.flush() with pytest.raises(exceptions.TcpDisconnect): - websockets_frame.Frame.from_file(rfile) + _, _, _ = websocket_utils.read_raw_frame(rfile) def test_close(self): self.setup_connection() @@ -349,9 +349,9 @@ class TestClose(_WebSocketTest): self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE))) self.client.wfile.flush() - websockets_frame.Frame.from_file(self.client.rfile) + _ = websocket_utils.read_raw_frame(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets_frame.Frame.from_file(self.client.rfile) + _ = websocket_utils.read_raw_frame(self.client.rfile) def test_close_payload_1(self): self.setup_connection() @@ -359,9 +359,9 @@ class TestClose(_WebSocketTest): self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE, payload=b'\00\42'))) self.client.wfile.flush() - websockets_frame.Frame.from_file(self.client.rfile) + _ = websocket_utils.read_raw_frame(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets_frame.Frame.from_file(self.client.rfile) + _ = websocket_utils.read_raw_frame(self.client.rfile) def test_close_payload_2(self): self.setup_connection() @@ -369,9 +369,9 @@ class TestClose(_WebSocketTest): self.client.wfile.write(bytes(websockets_frame.Frame(fin=1, mask=1, opcode=Opcode.CLOSE, payload=b'\00\42foobar'))) self.client.wfile.flush() - websockets_frame.Frame.from_file(self.client.rfile) + _ = websocket_utils.read_raw_frame(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets_frame.Frame.from_file(self.client.rfile) + _ = websocket_utils.read_raw_frame(self.client.rfile) class TestInvalidFrame(_WebSocketTest): @@ -384,8 +384,7 @@ class TestInvalidFrame(_WebSocketTest): def test_invalid_frame(self): self.setup_connection() - # with pytest.raises(exceptions.TcpDisconnect): - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) code, = struct.unpack('!H', frame.payload[:2]) assert code == 1002 assert frame.payload[2:].startswith(b'Invalid opcode') @@ -410,11 +409,11 @@ class TestStreaming(_WebSocketTest): frame = None if not streaming: with pytest.raises(exceptions.TcpDisconnect): # Reader.safe_read get nothing as result - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame is None else: - frame = websockets_frame.Frame.from_file(self.client.rfile) + _, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) assert frame assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received @@ -427,33 +426,33 @@ class TestExtension(_WebSocketTest): wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00') wfile.flush() - frame = websockets_frame.Frame.from_file(rfile) - assert frame.header.rsv1 + header, _, _ = websocket_utils.read_raw_frame(rfile) + assert header.rsv.rsv1 wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00') wfile.flush() - frame = websockets_frame.Frame.from_file(rfile) - assert frame.header.rsv1 + header, _, _ = websocket_utils.read_raw_frame(rfile) + assert header.rsv.rsv1 wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00') wfile.flush() def test_extension(self): self.setup_connection(True) - frame = websockets_frame.Frame.from_file(self.client.rfile) - assert frame.header.rsv1 + header, _, _ = websocket_utils.read_raw_frame(self.client.rfile) + assert header.rsv.rsv1 self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v') self.client.wfile.flush() - frame = websockets_frame.Frame.from_file(self.client.rfile) - assert frame.header.rsv1 + header, _, _ = websocket_utils.read_raw_frame(self.client.rfile) + assert header.rsv.rsv1 self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c') self.client.wfile.flush() - frame = websockets_frame.Frame.from_file(self.client.rfile) - assert frame.header.rsv1 + header, _, _ = websocket_utils.read_raw_frame(self.client.rfile) + assert header.rsv.rsv1 assert len(self.master.state.flows[1].messages) == 5 assert self.master.state.flows[1].messages[0].content == 'server-foobar' @@ -482,8 +481,8 @@ class TestInjectMessageClient(_WebSocketTest): self.proxy.set_addons(Inject()) self.setup_connection() - frame = websockets_frame.Frame.from_file(self.client.rfile) - assert frame.header.opcode == Opcode.TEXT + header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + assert header.opcode == Opcode.TEXT assert frame.payload == b'This is an injected message!' @@ -491,8 +490,8 @@ class TestInjectMessageServer(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): - frame = websockets_frame.Frame.from_file(rfile) - assert frame.header.opcode == Opcode.TEXT + header, frame, _ = websocket_utils.read_raw_frame(rfile) + assert header.opcode == Opcode.TEXT success = frame.payload == b'This is an injected message!' wfile.write(bytes(websockets_frame.Frame(fin=1, opcode=Opcode.TEXT, payload=str(success).encode()))) @@ -506,6 +505,6 @@ class TestInjectMessageServer(_WebSocketTest): self.proxy.set_addons(Inject()) self.setup_connection() - frame = websockets_frame.Frame.from_file(self.client.rfile) - assert frame.header.opcode == Opcode.TEXT + header, frame, _ = websocket_utils.read_raw_frame(self.client.rfile) + assert header.opcode == Opcode.TEXT assert frame.payload == b'True'