diff --git a/examples/complex/websocket_inject_message.py b/examples/complex/websocket_inject_message.py new file mode 100644 index 000000000..e9c3ea0c4 --- /dev/null +++ b/examples/complex/websocket_inject_message.py @@ -0,0 +1,23 @@ +""" +This example shows how to inject a WebSocket message to the client. +Every new WebSocket connection will trigger a new asyncio task that +periodically injects a new message to the client. +""" +import asyncio +import mitmproxy.websocket + + +class InjectWebSocketMessage: + + async def inject(self, flow: mitmproxy.websocket.WebSocketFlow): + i = 0 + while not flow.ended and not flow.error: + await asyncio.sleep(5) + flow.inject_message(flow.client_conn, 'This is the #{} an injected message!'.format(i)) + i += 1 + + def websocket_start(self, flow): + asyncio.get_event_loop().create_task(self.inject(flow)) + + +addons = [InjectWebSocketMessage()] diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 7349c3259..0d1964a60 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -1,3 +1,4 @@ +import queue import socket from OpenSSL import SSL @@ -165,8 +166,18 @@ class WebSocketLayer(base.Layer): return False + def _inject_messages(self, endpoint, message_queue): + while True: + try: + payload = message_queue.get_nowait() + self.connections[endpoint].send_data(payload, final=True) + data = self.connections[endpoint].bytes_to_send() + endpoint.send(data) + except queue.Empty: + break + def __call__(self): - self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self) + self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow) self.flow.metadata['websocket_handshake'] = self.handshake_flow.id self.handshake_flow.metadata['websocket_flow'] = self.flow.id self.channel.ask("websocket_start", self.flow) @@ -176,6 +187,9 @@ class WebSocketLayer(base.Layer): try: while not self.channel.should_exit.is_set(): + self._inject_messages(self.client_conn, self.flow._inject_messages_client) + self._inject_messages(self.server_conn, self.flow._inject_messages_server) + r = tcp.ssl_read_select(conns, 0.1) for conn in r: source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn @@ -198,4 +212,5 @@ class WebSocketLayer(base.Layer): self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e))) self.channel.tell("websocket_error", self.flow) finally: + self.flow.ended = True self.channel.tell("websocket_end", self.flow) diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 9de2e26e3..f13b9eec3 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -1,4 +1,5 @@ import time +import queue from typing import List, Optional from wsproto.frame_protocol import CloseReason @@ -77,6 +78,11 @@ class WebSocketFlow(flow.Flow): """True of this connection is streaming directly to the other endpoint.""" self.handshake_flow = handshake_flow """The HTTP flow containing the initial WebSocket handshake.""" + self.ended = False + """True when the WebSocket connection has been closed.""" + + self._inject_messages_client = queue.Queue(maxsize=1) + self._inject_messages_server = queue.Queue(maxsize=1) if handshake_flow: self.client_key = websockets.get_client_key(handshake_flow.request.headers) @@ -134,3 +140,25 @@ class WebSocketFlow(flow.Flow): direction="->" if message.from_client else "<-", endpoint=self.handshake_flow.request.path, ) + + def inject_message(self, endpoint, payload): + """ + Inject and send a full WebSocket message to the remote endpoint. + This might corrupt your WebSocket connection! Be careful! + + The endpoint needs to be either flow.client_conn or flow.server_conn. + + If ``payload`` is of type ``bytes`` then the message is flagged as + being binary If it is of type ``str`` encoded as UTF-8 and sent as + text. + + :param payload: The message body to send. + :type payload: ``bytes`` or ``str`` + """ + + if endpoint == self.client_conn: + self._inject_messages_client.put(payload) + elif endpoint == self.server_conn: + self._inject_messages_server.put(payload) + else: + raise ValueError('Invalid endpoint') diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 1f4e2bca9..0b26ed29e 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -467,3 +467,46 @@ class TestExtension(_WebSocketTest): assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY + + +class TestInjectMessageClient(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + pass + + def test_inject_message_client(self): + class Inject: + def websocket_start(self, flow): + flow.inject_message(flow.client_conn, 'This is an injected message!') + + self.proxy.set_addons(Inject()) + self.setup_connection() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.opcode == websockets.OPCODE.TEXT + assert frame.payload == b'This is an injected message!' + + +class TestInjectMessageServer(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.TEXT + success = frame.payload == b'This is an injected message!' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=str(success).encode()))) + wfile.flush() + + def test_inject_message_server(self): + class Inject: + def websocket_start(self, flow): + flow.inject_message(flow.server_conn, 'This is an injected message!') + + self.proxy.set_addons(Inject()) + self.setup_connection() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.opcode == websockets.OPCODE.TEXT + assert frame.payload == b'True' diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py index fcacec366..bd4bb4c93 100644 --- a/test/mitmproxy/test_websocket.py +++ b/test/mitmproxy/test_websocket.py @@ -92,3 +92,15 @@ class TestWebSocketFlow: assert not f.messages[-1].killed f.messages[-1].kill() assert f.messages[-1].killed + + def test_inject_message(self): + f = tflow.twebsocketflow() + + with pytest.raises(ValueError): + f.inject_message(None, 'foobar') + + f.inject_message(f.client_conn, 'foobar') + assert f._inject_messages_client.qsize() == 1 + + f.inject_message(f.server_conn, 'foobar') + assert f._inject_messages_client.qsize() == 1