add sans-io websocket layer

This commit is contained in:
Ujjwal Verma 2017-08-01 01:00:37 +05:30 committed by Maximilian Hils
parent 6d59d213e3
commit 4c2fb7f250
3 changed files with 247 additions and 8 deletions

View File

@ -1,13 +1,16 @@
import typing import typing
from warnings import warn from warnings import warn
from unittest import mock
import sys import sys
import h11 import h11
from mitmproxy.net import http
from mitmproxy.proxy.protocol2 import events, commands, websocket from mitmproxy.proxy.protocol2 import events, commands, websocket
from mitmproxy.proxy.protocol2.context import ClientServerContext from mitmproxy.proxy.protocol2.context import ClientServerContext
from mitmproxy.proxy.protocol2.layer import Layer from mitmproxy.proxy.protocol2.layer import Layer
from mitmproxy.proxy.protocol2.utils import expect from mitmproxy.proxy.protocol2.utils import expect
from mitmproxy.net import websockets
class HTTPLayer(Layer): class HTTPLayer(Layer):
@ -22,7 +25,7 @@ class HTTPLayer(Layer):
def __init__(self, context: ClientServerContext): def __init__(self, context: ClientServerContext):
super().__init__(context) super().__init__(context)
self.state = self.read_request_headers self.state = self.read_request_headers
self.flow = mock.Mock("mitmproxy.http.HTTPFlow")
self.client_conn = h11.Connection(h11.SERVER) self.client_conn = h11.Connection(h11.SERVER)
self.server_conn = h11.Connection(h11.CLIENT) self.server_conn = h11.Connection(h11.CLIENT)
@ -63,6 +66,7 @@ class HTTPLayer(Layer):
if self.client_conn.client_is_waiting_for_100_continue: if self.client_conn.client_is_waiting_for_100_continue:
raise NotImplementedError() raise NotImplementedError()
self.flow.request.headers = http.Headers(event.headers)
self.flow_events[0].append(event) self.flow_events[0].append(event)
self.state = self.read_request_body self.state = self.read_request_body
yield from self.read_request_body() # there may already be further events. yield from self.read_request_body() # there may already be further events.
@ -108,9 +112,9 @@ class HTTPLayer(Layer):
self.state = self.read_response_body self.state = self.read_response_body
yield from self.read_response_body() # there may already be further events. yield from self.read_response_body() # there may already be further events.
elif isinstance(event, h11.InformationalResponse): elif isinstance(event, h11.InformationalResponse):
if event.status_code == 101: self.flow.response.headers = http.Headers(event.headers)
# FIXME: check if this is actually WebSocket if event.status_code == 101 and websockets.check_handshake(self.flow.response.headers):
child_layer = websocket.WebsocketLayer(self.context, None) child_layer = websocket.WebsocketLayer(self.context, self.flow)
yield from child_layer.handle_event(events.Start()) yield from child_layer.handle_event(events.Start())
self._handle_event = child_layer.handle_event self._handle_event = child_layer.handle_event
return return

View File

@ -0,0 +1,146 @@
from unittest import mock
import pytest
from . import tutils
from .. import commands
from .. import events
from .. import websocket
from mitmproxy.test import tflow
@pytest.fixture
def ws_playbook(tctx):
tctx.server.connected = True
playbook = tutils.playbook(
websocket.WebsocketLayer(
tctx,
tflow.twebsocketflow().handshake_flow
)
)
with mock.patch("os.urandom") as m:
m.return_value = b"\x10\x11\x12\x13"
yield playbook
def test_simple(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1, None)
>> events.DataReceived(tctx.client, b"\x82\x85\x10\x11\x12\x13Xt~\x7f\x7f") # Frame with payload b"Hello"
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1, None)
<< commands.SendData(tctx.server, b"\x82\x85\x10\x11\x12\x13Xt~\x7f\x7f")
>> events.DataReceived(tctx.server, b'\x81\x05Hello') # Frame with payload "Hello"
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1, None)
<< commands.SendData(tctx.client, b'\x81\x05Hello')
>> events.DataReceived(tctx.client, b'\x88\x82\x10\x11\x12\x13\x13\xf9') # Closing frame
<< commands.SendData(tctx.server, b'\x88\x82\x10\x11\x12\x13\x13\xf9')
<< commands.SendData(tctx.client, b'\x88\x02\x03\xe8')
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1, None)
>> events.DataReceived(tctx.server, b'\x81\x05Hello')
<< None
)
assert len(f().messages) == 2
def test_server_close(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1, None)
>> events.DataReceived(tctx.server, b'\x88\x02\x03\xe8')
<< commands.SendData(tctx.client, b'\x88\x02\x03\xe8')
<< commands.SendData(tctx.server, b'\x88\x82\x10\x11\x12\x13\x13\xf9')
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1, None)
<< commands.CloseConnection(tctx.client)
)
def test_ping_pong(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1, None)
>> events.DataReceived(tctx.client, b'\x89\x80\x10\x11\x12\x13') # Ping
<< commands.Log("info", "Websocket PING received ")
>> events.HookReply(-1, None)
<< commands.SendData(tctx.server, b'\x89\x80\x10\x11\x12\x13')
<< commands.SendData(tctx.client, b'\x8a\x00')
>> events.DataReceived(tctx.server, b'\x8a\x00') # Pong
<< commands.Log("info", "Websocket PONG received ")
>> events.HookReply(-1, None)
)
def test_connection_failed(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1, None)
>> events.DataReceived(tctx.client, b"Not a valid frame")
<< commands.SendData(tctx.server, b'\x88\x94\x10\x11\x12\x13\x13\xfb[}fp~zt1}cs~vv0!jv')
<< commands.SendData(tctx.client, b'\x88\x14\x03\xeaInvalid opcode 0xe')
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1, None)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1, None)
)
def test_extension(tctx):
f = tutils.Placeholder()
tctx.server.connected = True
handshake_flow = tflow.twebsocketflow().handshake_flow
handshake_flow.request.headers["sec-websocket-extensions"] = "permessage-deflate;"
handshake_flow.response.headers["sec-websocket-extensions"] = "permessage-deflate;"
playbook = tutils.playbook(websocket.WebsocketLayer(tctx, handshake_flow))
with mock.patch("os.urandom") as m:
m.return_value = b"\x10\x11\x12\x13"
assert (
playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1, None)
>> events.DataReceived(tctx.client, b'\xc1\x87\x10\x11\x12\x13\xe2Y\xdf\xda\xd9\x16\x12') # Compressed Frame
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1, None)
<< commands.SendData(tctx.server, b'\xc1\x87\x10\x11\x12\x13\xe2Y\xdf\xda\xd9\x16\x12')
>> events.DataReceived(tctx.server, b'\xc1\x07\xf2H\xcd\xc9\xc9\x07\x00') # Compressed Frame
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1, None)
<< commands.SendData(tctx.client, b'\xc1\x07\xf2H\xcd\xc9\xc9\x07\x00')
)
assert len(f().messages) == 2
assert f().messages[0].content == "Hello"
assert f().messages[1].content == "Hello"
def test_connection_closed(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1, None)
>> events.ConnectionClosed(tctx.server)
<< commands.Log("error", "Connection closed abnormally")
>> events.HookReply(-1, None)
<< commands.CloseConnection(tctx.client)
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1, None)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1, None)
)
assert f().error

View File

@ -1,8 +1,10 @@
from mitmproxy import websocket, http from mitmproxy import websocket, http, flow
from mitmproxy.proxy.protocol2 import events, commands from mitmproxy.proxy.protocol2 import events, commands
from mitmproxy.proxy.protocol2.context import ClientServerContext from mitmproxy.proxy.protocol2.context import ClientServerContext
from mitmproxy.proxy.protocol2.layer import Layer from mitmproxy.proxy.protocol2.layer import Layer
from mitmproxy.proxy.protocol2.utils import expect from mitmproxy.proxy.protocol2.utils import expect
from wsproto import connection as wsconn
from wsproto import events as wsevents
class WebsocketLayer(Layer): class WebsocketLayer(Layer):
@ -16,14 +18,101 @@ class WebsocketLayer(Layer):
super().__init__(context) super().__init__(context)
self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow) self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow)
assert context.server.connected assert context.server.connected
self.client_frame_buffer = []
self.server_frame_buffer = []
extension = self.flow.server_extensions
self.client_conn = wsconn.WSConnection(wsconn.SERVER, wsconn.ConnectionState.OPEN,
extensions=[extension] if extension else None)
self.server_conn = wsconn.WSConnection(wsconn.CLIENT, wsconn.ConnectionState.OPEN,
extensions=[extension] if extension else None)
@expect(events.Start) @expect(events.Start)
def start(self, _) -> commands.TCommandGenerator: def start(self, _) -> commands.TCommandGenerator:
yield from () yield commands.Hook("websocket_start", self.flow)
self._handle_event = self.relay_messages self._handle_event = self.relay_messages
_handle_event = start
@expect(events.DataReceived, events.ConnectionClosed) @expect(events.DataReceived, events.ConnectionClosed)
def relay_messages(self, event: events.Event) -> commands.TCommandGenerator: def relay_messages(self, event: events.Event) -> commands.TCommandGenerator:
raise NotImplementedError() if isinstance(event, events.DataReceived):
from_client = event.connection == self.context.client
if from_client:
source = self.client_conn
other = self.server_conn
fb = self.client_frame_buffer
send_to = self.context.server
else:
source = self.server_conn
other = self.client_conn
fb = self.server_frame_buffer
send_to = self.context.client
_handle_event = start source.receive_bytes(event.data)
for ws_event in source.events():
if isinstance(ws_event, wsevents.DataReceived):
fb.append(ws_event.data)
if ws_event.message_finished:
if isinstance(ws_event, wsevents.BytesReceived):
payload = b"".join(fb)
else:
payload = "".join(fb)
fb.clear()
websocket_message = websocket.WebSocketMessage(0x1 if isinstance(ws_event, wsevents.TextReceived) else 0x2,
from_client, payload)
self.flow.messages.append(websocket_message)
yield commands.Hook("websocket_message", self.flow)
other.send_data(ws_event.data, ws_event.message_finished)
yield commands.SendData(send_to, other.bytes_to_send())
elif isinstance(ws_event, wsevents.PingReceived):
yield commands.Log(
"info",
"Websocket PING received {}".format(ws_event.payload.decode())
)
other.ping()
yield commands.SendData(send_to, other.bytes_to_send())
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
elif isinstance(ws_event, wsevents.PongReceived):
yield commands.Log(
"info",
"Websocket PONG received {}".format(ws_event.payload.decode())
)
elif isinstance(ws_event, wsevents.ConnectionClosed):
other.close(ws_event.code, ws_event.reason)
yield commands.SendData(send_to, other.bytes_to_send())
# FIXME: Wait for other end to actually send the closing frame
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
if ws_event.code != 1000:
self.flow.error = flow.Error(
"WebSocket connection closed unexpectedly by {}: {}".format(
"client" if from_client else "server",
ws_event.reason
)
)
yield commands.Hook("websocket_error", self.flow)
yield commands.Hook("websocket_end", self.flow)
if not from_client:
yield commands.CloseConnection(self.context.client)
self._handle_event = self.done
elif isinstance(event, events.ConnectionClosed):
yield commands.Log("error", "Connection closed abnormally")
self.flow.error = flow.Error(
"WebSocket connection closed unexpectedly by {}".format(
"client" if event.connection == self.context.client else "server"
)
)
if event.connection == self.context.server:
yield commands.CloseConnection(self.context.client)
yield commands.Hook("websocket_error", self.flow)
yield commands.Hook("websocket_end", self.flow)
self._handle_event = self.done
@expect(events.DataReceived, events.ConnectionClosed)
def done(self, _):
yield from ()