From 4c2fb7f25048dff9df8122e7a49217978412a030 Mon Sep 17 00:00:00 2001 From: Ujjwal Verma Date: Tue, 1 Aug 2017 01:00:37 +0530 Subject: [PATCH] add sans-io websocket layer --- mitmproxy/proxy/protocol2/http.py | 12 +- .../proxy/protocol2/test/test_websocket.py | 146 ++++++++++++++++++ mitmproxy/proxy/protocol2/websocket.py | 97 +++++++++++- 3 files changed, 247 insertions(+), 8 deletions(-) create mode 100644 mitmproxy/proxy/protocol2/test/test_websocket.py diff --git a/mitmproxy/proxy/protocol2/http.py b/mitmproxy/proxy/protocol2/http.py index 838c99cda..02aacff69 100644 --- a/mitmproxy/proxy/protocol2/http.py +++ b/mitmproxy/proxy/protocol2/http.py @@ -1,13 +1,16 @@ import typing from warnings import warn +from unittest import mock import sys import h11 +from mitmproxy.net import http from mitmproxy.proxy.protocol2 import events, commands, websocket from mitmproxy.proxy.protocol2.context import ClientServerContext from mitmproxy.proxy.protocol2.layer import Layer from mitmproxy.proxy.protocol2.utils import expect +from mitmproxy.net import websockets class HTTPLayer(Layer): @@ -22,7 +25,7 @@ class HTTPLayer(Layer): def __init__(self, context: ClientServerContext): super().__init__(context) self.state = self.read_request_headers - + self.flow = mock.Mock("mitmproxy.http.HTTPFlow") self.client_conn = h11.Connection(h11.SERVER) self.server_conn = h11.Connection(h11.CLIENT) @@ -63,6 +66,7 @@ class HTTPLayer(Layer): if self.client_conn.client_is_waiting_for_100_continue: raise NotImplementedError() + self.flow.request.headers = http.Headers(event.headers) self.flow_events[0].append(event) self.state = self.read_request_body yield from self.read_request_body() # there may already be further events. @@ -108,9 +112,9 @@ class HTTPLayer(Layer): self.state = self.read_response_body yield from self.read_response_body() # there may already be further events. elif isinstance(event, h11.InformationalResponse): - if event.status_code == 101: - # FIXME: check if this is actually WebSocket - child_layer = websocket.WebsocketLayer(self.context, None) + self.flow.response.headers = http.Headers(event.headers) + if event.status_code == 101 and websockets.check_handshake(self.flow.response.headers): + child_layer = websocket.WebsocketLayer(self.context, self.flow) yield from child_layer.handle_event(events.Start()) self._handle_event = child_layer.handle_event return diff --git a/mitmproxy/proxy/protocol2/test/test_websocket.py b/mitmproxy/proxy/protocol2/test/test_websocket.py new file mode 100644 index 000000000..d6a88b97f --- /dev/null +++ b/mitmproxy/proxy/protocol2/test/test_websocket.py @@ -0,0 +1,146 @@ +from unittest import mock + +import pytest + +from . import tutils +from .. import commands +from .. import events +from .. import websocket +from mitmproxy.test import tflow + + +@pytest.fixture +def ws_playbook(tctx): + tctx.server.connected = True + playbook = tutils.playbook( + websocket.WebsocketLayer( + tctx, + tflow.twebsocketflow().handshake_flow + ) + ) + with mock.patch("os.urandom") as m: + m.return_value = b"\x10\x11\x12\x13" + yield playbook + + +def test_simple(tctx, ws_playbook): + f = tutils.Placeholder() + + assert ( + ws_playbook + << commands.Hook("websocket_start", f) + >> events.HookReply(-1, None) + >> events.DataReceived(tctx.client, b"\x82\x85\x10\x11\x12\x13Xt~\x7f\x7f") # Frame with payload b"Hello" + << commands.Hook("websocket_message", f) + >> events.HookReply(-1, None) + << commands.SendData(tctx.server, b"\x82\x85\x10\x11\x12\x13Xt~\x7f\x7f") + >> events.DataReceived(tctx.server, b'\x81\x05Hello') # Frame with payload "Hello" + << commands.Hook("websocket_message", f) + >> events.HookReply(-1, None) + << commands.SendData(tctx.client, b'\x81\x05Hello') + >> events.DataReceived(tctx.client, b'\x88\x82\x10\x11\x12\x13\x13\xf9') # Closing frame + << commands.SendData(tctx.server, b'\x88\x82\x10\x11\x12\x13\x13\xf9') + << commands.SendData(tctx.client, b'\x88\x02\x03\xe8') + << commands.Hook("websocket_end", f) + >> events.HookReply(-1, None) + >> events.DataReceived(tctx.server, b'\x81\x05Hello') + << None + ) + + assert len(f().messages) == 2 + + +def test_server_close(tctx, ws_playbook): + f = tutils.Placeholder() + + assert ( + ws_playbook + << commands.Hook("websocket_start", f) + >> events.HookReply(-1, None) + >> events.DataReceived(tctx.server, b'\x88\x02\x03\xe8') + << commands.SendData(tctx.client, b'\x88\x02\x03\xe8') + << commands.SendData(tctx.server, b'\x88\x82\x10\x11\x12\x13\x13\xf9') + << commands.Hook("websocket_end", f) + >> events.HookReply(-1, None) + << commands.CloseConnection(tctx.client) + ) + + +def test_ping_pong(tctx, ws_playbook): + f = tutils.Placeholder() + + assert ( + ws_playbook + << commands.Hook("websocket_start", f) + >> events.HookReply(-1, None) + >> events.DataReceived(tctx.client, b'\x89\x80\x10\x11\x12\x13') # Ping + << commands.Log("info", "Websocket PING received ") + >> events.HookReply(-1, None) + << commands.SendData(tctx.server, b'\x89\x80\x10\x11\x12\x13') + << commands.SendData(tctx.client, b'\x8a\x00') + >> events.DataReceived(tctx.server, b'\x8a\x00') # Pong + << commands.Log("info", "Websocket PONG received ") + >> events.HookReply(-1, None) + ) + + +def test_connection_failed(tctx, ws_playbook): + f = tutils.Placeholder() + assert ( + ws_playbook + << commands.Hook("websocket_start", f) + >> events.HookReply(-1, None) + >> events.DataReceived(tctx.client, b"Not a valid frame") + << commands.SendData(tctx.server, b'\x88\x94\x10\x11\x12\x13\x13\xfb[}fp~zt1}cs~vv0!jv') + << commands.SendData(tctx.client, b'\x88\x14\x03\xeaInvalid opcode 0xe') + << commands.Hook("websocket_error", f) + >> events.HookReply(-1, None) + << commands.Hook("websocket_end", f) + >> events.HookReply(-1, None) + ) + + +def test_extension(tctx): + f = tutils.Placeholder() + tctx.server.connected = True + handshake_flow = tflow.twebsocketflow().handshake_flow + handshake_flow.request.headers["sec-websocket-extensions"] = "permessage-deflate;" + handshake_flow.response.headers["sec-websocket-extensions"] = "permessage-deflate;" + playbook = tutils.playbook(websocket.WebsocketLayer(tctx, handshake_flow)) + with mock.patch("os.urandom") as m: + m.return_value = b"\x10\x11\x12\x13" + assert ( + playbook + << commands.Hook("websocket_start", f) + >> events.HookReply(-1, None) + >> events.DataReceived(tctx.client, b'\xc1\x87\x10\x11\x12\x13\xe2Y\xdf\xda\xd9\x16\x12') # Compressed Frame + << commands.Hook("websocket_message", f) + >> events.HookReply(-1, None) + << commands.SendData(tctx.server, b'\xc1\x87\x10\x11\x12\x13\xe2Y\xdf\xda\xd9\x16\x12') + >> events.DataReceived(tctx.server, b'\xc1\x07\xf2H\xcd\xc9\xc9\x07\x00') # Compressed Frame + << commands.Hook("websocket_message", f) + >> events.HookReply(-1, None) + << commands.SendData(tctx.client, b'\xc1\x07\xf2H\xcd\xc9\xc9\x07\x00') + ) + assert len(f().messages) == 2 + assert f().messages[0].content == "Hello" + assert f().messages[1].content == "Hello" + + +def test_connection_closed(tctx, ws_playbook): + f = tutils.Placeholder() + assert ( + ws_playbook + << commands.Hook("websocket_start", f) + >> events.HookReply(-1, None) + >> events.ConnectionClosed(tctx.server) + << commands.Log("error", "Connection closed abnormally") + >> events.HookReply(-1, None) + << commands.CloseConnection(tctx.client) + << commands.Hook("websocket_error", f) + >> events.HookReply(-1, None) + << commands.Hook("websocket_end", f) + >> events.HookReply(-1, None) + ) + + assert f().error diff --git a/mitmproxy/proxy/protocol2/websocket.py b/mitmproxy/proxy/protocol2/websocket.py index d0cf37ce0..8027546b0 100644 --- a/mitmproxy/proxy/protocol2/websocket.py +++ b/mitmproxy/proxy/protocol2/websocket.py @@ -1,8 +1,10 @@ -from mitmproxy import websocket, http +from mitmproxy import websocket, http, flow from mitmproxy.proxy.protocol2 import events, commands from mitmproxy.proxy.protocol2.context import ClientServerContext from mitmproxy.proxy.protocol2.layer import Layer from mitmproxy.proxy.protocol2.utils import expect +from wsproto import connection as wsconn +from wsproto import events as wsevents class WebsocketLayer(Layer): @@ -16,14 +18,101 @@ class WebsocketLayer(Layer): super().__init__(context) self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow) assert context.server.connected + self.client_frame_buffer = [] + self.server_frame_buffer = [] + extension = self.flow.server_extensions + self.client_conn = wsconn.WSConnection(wsconn.SERVER, wsconn.ConnectionState.OPEN, + extensions=[extension] if extension else None) + + self.server_conn = wsconn.WSConnection(wsconn.CLIENT, wsconn.ConnectionState.OPEN, + extensions=[extension] if extension else None) @expect(events.Start) def start(self, _) -> commands.TCommandGenerator: - yield from () + yield commands.Hook("websocket_start", self.flow) self._handle_event = self.relay_messages + _handle_event = start + @expect(events.DataReceived, events.ConnectionClosed) def relay_messages(self, event: events.Event) -> commands.TCommandGenerator: - raise NotImplementedError() + if isinstance(event, events.DataReceived): + from_client = event.connection == self.context.client + if from_client: + source = self.client_conn + other = self.server_conn + fb = self.client_frame_buffer + send_to = self.context.server + else: + source = self.server_conn + other = self.client_conn + fb = self.server_frame_buffer + send_to = self.context.client - _handle_event = start + source.receive_bytes(event.data) + + for ws_event in source.events(): + if isinstance(ws_event, wsevents.DataReceived): + fb.append(ws_event.data) + if ws_event.message_finished: + if isinstance(ws_event, wsevents.BytesReceived): + payload = b"".join(fb) + else: + payload = "".join(fb) + fb.clear() + websocket_message = websocket.WebSocketMessage(0x1 if isinstance(ws_event, wsevents.TextReceived) else 0x2, + from_client, payload) + self.flow.messages.append(websocket_message) + yield commands.Hook("websocket_message", self.flow) + + other.send_data(ws_event.data, ws_event.message_finished) + yield commands.SendData(send_to, other.bytes_to_send()) + elif isinstance(ws_event, wsevents.PingReceived): + yield commands.Log( + "info", + "Websocket PING received {}".format(ws_event.payload.decode()) + ) + other.ping() + yield commands.SendData(send_to, other.bytes_to_send()) + yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send()) + elif isinstance(ws_event, wsevents.PongReceived): + yield commands.Log( + "info", + "Websocket PONG received {}".format(ws_event.payload.decode()) + ) + + elif isinstance(ws_event, wsevents.ConnectionClosed): + other.close(ws_event.code, ws_event.reason) + yield commands.SendData(send_to, other.bytes_to_send()) + # FIXME: Wait for other end to actually send the closing frame + yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send()) + + if ws_event.code != 1000: + self.flow.error = flow.Error( + "WebSocket connection closed unexpectedly by {}: {}".format( + "client" if from_client else "server", + ws_event.reason + ) + ) + yield commands.Hook("websocket_error", self.flow) + + yield commands.Hook("websocket_end", self.flow) + if not from_client: + yield commands.CloseConnection(self.context.client) + self._handle_event = self.done + elif isinstance(event, events.ConnectionClosed): + yield commands.Log("error", "Connection closed abnormally") + self.flow.error = flow.Error( + "WebSocket connection closed unexpectedly by {}".format( + "client" if event.connection == self.context.client else "server" + ) + ) + if event.connection == self.context.server: + yield commands.CloseConnection(self.context.client) + yield commands.Hook("websocket_error", self.flow) + yield commands.Hook("websocket_end", self.flow) + self._handle_event = self.done + + @expect(events.DataReceived, events.ConnectionClosed) + def done(self, _): + yield from ()