From f5fafbfcb56bbc3fb7cca7ed32dd7b3b41c39e83 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 12 Dec 2017 21:47:24 +0100 Subject: [PATCH] vendoring of wsproto https://github.com/python-hyper/wsproto.git commit 5ea2da61266796666f5de6461aaae22e6b00deba --- mitmproxy/contrib/wsproto/compat.py | 20 + mitmproxy/contrib/wsproto/connection.py | 477 ++++++++++++++++ mitmproxy/contrib/wsproto/events.py | 81 +++ mitmproxy/contrib/wsproto/extensions.py | 257 +++++++++ mitmproxy/contrib/wsproto/frame_protocol.py | 579 ++++++++++++++++++++ mitmproxy/proxy/protocol/websocket.py | 8 +- setup.py | 1 + 7 files changed, 1419 insertions(+), 4 deletions(-) create mode 100644 mitmproxy/contrib/wsproto/compat.py create mode 100644 mitmproxy/contrib/wsproto/connection.py create mode 100644 mitmproxy/contrib/wsproto/events.py create mode 100644 mitmproxy/contrib/wsproto/extensions.py create mode 100644 mitmproxy/contrib/wsproto/frame_protocol.py diff --git a/mitmproxy/contrib/wsproto/compat.py b/mitmproxy/contrib/wsproto/compat.py new file mode 100644 index 000000000..1911f83cf --- /dev/null +++ b/mitmproxy/contrib/wsproto/compat.py @@ -0,0 +1,20 @@ +# flake8: noqa + +import sys + + +PY2 = sys.version_info.major == 2 +PY3 = sys.version_info.major == 3 + + +if PY3: + unicode = str + + def Utf8Validator(): + return None +else: + unicode = unicode + try: + from wsaccel.utf8validator import Utf8Validator + except ImportError: + from .utf8validator import Utf8Validator diff --git a/mitmproxy/contrib/wsproto/connection.py b/mitmproxy/contrib/wsproto/connection.py new file mode 100644 index 000000000..f994cd3ab --- /dev/null +++ b/mitmproxy/contrib/wsproto/connection.py @@ -0,0 +1,477 @@ +# -*- coding: utf-8 -*- +""" +wsproto/connection +~~~~~~~~~~~~~~ + +An implementation of a WebSocket connection. +""" + +import os +import base64 +import hashlib +from collections import deque + +from enum import Enum + +import h11 + +from .events import ( + ConnectionRequested, ConnectionEstablished, ConnectionClosed, + ConnectionFailed, TextReceived, BytesReceived, PingReceived, PongReceived +) +from .frame_protocol import FrameProtocol, ParseFailed, CloseReason, Opcode + + +# RFC6455, Section 1.3 - Opening Handshake +ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +class ConnectionState(Enum): + """ + RFC 6455, Section 4 - Opening Handshake + """ + CONNECTING = 0 + OPEN = 1 + CLOSING = 2 + CLOSED = 3 + + +class ConnectionType(Enum): + CLIENT = 1 + SERVER = 2 + + +CLIENT = ConnectionType.CLIENT +SERVER = ConnectionType.SERVER + + +# Some convenience utilities for working with HTTP headers +def _normed_header_dict(h11_headers): + # This mangles Set-Cookie headers. But it happens that we don't care about + # any of those, so it's OK. For every other HTTP header, if there are + # multiple instances then you're allowed to join them together with + # commas. + name_to_values = {} + for name, value in h11_headers: + name_to_values.setdefault(name, []).append(value) + name_to_normed_value = {} + for name, values in name_to_values.items(): + name_to_normed_value[name] = b", ".join(values) + return name_to_normed_value + + +# We use this for parsing the proposed protocol list, and for parsing the +# proposed and accepted extension lists. For the proposed protocol list it's +# fine, because the ABNF is just 1#token. But for the extension lists, it's +# wrong, because those can contain quoted strings, which can in turn contain +# commas. 
XX FIXME
+def _split_comma_header(value):
+    return [piece.decode('ascii').strip() for piece in value.split(b',')]
+
+
+class WSConnection(object):
+    """
+    A low-level WebSocket connection object.
+
+    This wraps two other protocol objects, an HTTP/1.1 protocol object used
+    to do the initial HTTP upgrade handshake and a WebSocket frame protocol
+    object used to exchange messages and other control frames.
+
+    :param conn_type: Whether this object is on the client- or server-side of
+        a connection. To initialise as a client pass ``CLIENT``, otherwise
+        pass ``SERVER``.
+    :type conn_type: ``ConnectionType``
+
+    :param host: The hostname to pass to the server when acting as a client.
+    :type host: ``str``
+
+    :param resource: The resource (aka path) to pass to the server when acting
+        as a client.
+    :type resource: ``str``
+
+    :param extensions: A list of extensions to use on this connection.
+        Extensions should be instances of a subclass of
+        :class:`Extension <wsproto.extensions.Extension>`.
+
+    :param subprotocols: A list of subprotocols to request when acting as a
+        client, ordered by preference. This has no impact on the connection
+        itself.
+    :type subprotocols: ``list`` of ``str``
+    """
+
+    def __init__(self, conn_type, host=None, resource=None, extensions=None,
+                 subprotocols=None):
+        self.client = conn_type is ConnectionType.CLIENT
+
+        self.host = host
+        self.resource = resource
+
+        self.subprotocols = subprotocols or []
+        self.extensions = extensions or []
+
+        self.version = b'13'
+
+        self._state = ConnectionState.CONNECTING
+        self._close_reason = None
+
+        self._nonce = None
+        self._outgoing = b''
+        self._events = deque()
+        self._proto = None
+
+        if self.client:
+            self._upgrade_connection = h11.Connection(h11.CLIENT)
+        else:
+            self._upgrade_connection = h11.Connection(h11.SERVER)
+
+        if self.client:
+            if self.host is None:
+                raise ValueError(
+                    "Host must not be None for a client-side connection.")
+            if self.resource is None:
+                raise ValueError(
+                    "Resource must not be None for a client-side connection.")
+            self.initiate_connection()
+
+    def initiate_connection(self):
+        self._generate_nonce()
+
+        headers = {
+            b"Host": self.host.encode('ascii'),
+            b"Upgrade": b'WebSocket',
+            b"Connection": b'Upgrade',
+            b"Sec-WebSocket-Key": self._nonce,
+            b"Sec-WebSocket-Version": self.version,
+        }
+
+        if self.subprotocols:
+            headers[b"Sec-WebSocket-Protocol"] = ", ".join(self.subprotocols)
+
+        if self.extensions:
+            offers = {e.name: e.offer(self) for e in self.extensions}
+            extensions = []
+            for name, params in offers.items():
+                if params is True:
+                    extensions.append(name.encode('ascii'))
+                elif params:
+                    # py34 annoyance: doesn't support bytestring formatting
+                    extensions.append(('%s; %s' % (name, params))
+                                      .encode("ascii"))
+            if extensions:
+                headers[b'Sec-WebSocket-Extensions'] = b', '.join(extensions)
+
+        upgrade = h11.Request(method=b'GET', target=self.resource,
+                              headers=headers.items())
+        self._outgoing += self._upgrade_connection.send(upgrade)
+
+    def send_data(self, payload, final=True):
+        """
+        Send a message or part of a message to the remote peer.
+
+        If ``final`` is ``False`` it indicates that this is part of a longer
+        message. If ``final`` is ``True`` it indicates that this is either a
+        self-contained message or the last part of a longer message.
+
+        If ``payload`` is of type ``bytes`` then the message is flagged as
+        being binary. If it is of type ``str`` it is encoded as UTF-8 and
+        sent as text.
+
+        :param payload: The message body to send.
+ :type payload: ``bytes`` or ``str`` + + :param final: Whether there are more parts to this message to be sent. + :type final: ``bool`` + """ + + self._outgoing += self._proto.send_data(payload, final) + + def close(self, code=CloseReason.NORMAL_CLOSURE, reason=None): + self._outgoing += self._proto.close(code, reason) + self._state = ConnectionState.CLOSING + + @property + def closed(self): + return self._state is ConnectionState.CLOSED + + def bytes_to_send(self, amount=None): + """ + Return any data that is to be sent to the remote peer. + + :param amount: (optional) The maximum number of bytes to be provided. + If ``None`` or not provided it will return all available bytes. + :type amount: ``int`` + """ + + if amount is None: + data = self._outgoing + self._outgoing = b'' + else: + data = self._outgoing[:amount] + self._outgoing = self._outgoing[amount:] + + return data + + def receive_bytes(self, data): + """ + Pass some received bytes to the connection for processing. + + :param data: The data received from the remote peer. + :type data: ``bytes`` + """ + + if data is None and self._state is ConnectionState.OPEN: + # "If _The WebSocket Connection is Closed_ and no Close control + # frame was received by the endpoint (such as could occur if the + # underlying transport connection is lost), _The WebSocket + # Connection Close Code_ is considered to be 1006." + self._events.append(ConnectionClosed(CloseReason.ABNORMAL_CLOSURE)) + self._state = ConnectionState.CLOSED + return + elif data is None: + self._state = ConnectionState.CLOSED + return + + if self._state is ConnectionState.CONNECTING: + event, data = self._process_upgrade(data) + if event is not None: + self._events.append(event) + + if self._state is ConnectionState.OPEN: + self._proto.receive_bytes(data) + + def _process_upgrade(self, data): + self._upgrade_connection.receive_data(data) + while True: + try: + event = self._upgrade_connection.next_event() + except h11.RemoteProtocolError: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad HTTP message"), b'' + if event is h11.NEED_DATA: + break + elif self.client and isinstance(event, (h11.InformationalResponse, + h11.Response)): + data = self._upgrade_connection.trailing_data[0] + return self._establish_client_connection(event), data + elif not self.client and isinstance(event, h11.Request): + return self._process_connection_request(event), None + else: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad HTTP message"), b'' + + self._incoming = b'' + return None, None + + def events(self): + """ + Return a generator that provides any events that have been generated + by protocol activity. 
+ + :returns: generator + """ + + while self._events: + yield self._events.popleft() + + if self._proto is None: + return + + try: + for frame in self._proto.received_frames(): + if frame.opcode is Opcode.PING: + assert frame.frame_finished and frame.message_finished + self._outgoing += self._proto.pong(frame.payload) + yield PingReceived(frame.payload) + + elif frame.opcode is Opcode.PONG: + assert frame.frame_finished and frame.message_finished + yield PongReceived(frame.payload) + + elif frame.opcode is Opcode.CLOSE: + code, reason = frame.payload + self.close(code, reason) + yield ConnectionClosed(code, reason) + + elif frame.opcode is Opcode.TEXT: + yield TextReceived(frame.payload, + frame.frame_finished, + frame.message_finished) + + elif frame.opcode is Opcode.BINARY: + yield BytesReceived(frame.payload, + frame.frame_finished, + frame.message_finished) + except ParseFailed as exc: + # XX FIXME: apparently autobahn intentionally deviates from the + # spec in that on protocol errors it just closes the connection + # rather than trying to send a CLOSE frame. Investigate whether we + # should do the same. + self.close(code=exc.code, reason=str(exc)) + yield ConnectionClosed(exc.code, reason=str(exc)) + + def _generate_nonce(self): + # os.urandom may be overkill for this use case, but I don't think this + # is a bottleneck, and better safe than sorry... + self._nonce = base64.b64encode(os.urandom(16)) + + def _generate_accept_token(self, token): + accept_token = token + ACCEPT_GUID + accept_token = hashlib.sha1(accept_token).digest() + return base64.b64encode(accept_token) + + def _establish_client_connection(self, event): + if event.status_code != 101: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad status code from server") + headers = _normed_header_dict(event.headers) + if headers[b'connection'].lower() != b'upgrade': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Connection: Upgrade header") + if headers[b'upgrade'].lower() != b'websocket': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Upgrade: WebSocket header") + + accept_token = self._generate_accept_token(self._nonce) + if headers[b'sec-websocket-accept'] != accept_token: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Bad accept token") + + subprotocol = headers.get(b'sec-websocket-protocol', None) + if subprotocol is not None: + subprotocol = subprotocol.decode('ascii') + if subprotocol not in self.subprotocols: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "unrecognized subprotocol {!r}" + .format(subprotocol)) + + extensions = headers.get(b'sec-websocket-extensions', None) + if extensions: + accepts = _split_comma_header(extensions) + + for accept in accepts: + name = accept.split(';', 1)[0].strip() + for extension in self.extensions: + if extension.name == name: + extension.finalize(self, accept) + break + else: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "unrecognized extension {!r}" + .format(name)) + + self._proto = FrameProtocol(self.client, self.extensions) + self._state = ConnectionState.OPEN + return ConnectionEstablished(subprotocol, extensions) + + def _process_connection_request(self, event): + if event.method != b'GET': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Request method must be GET") + headers = _normed_header_dict(event.headers) + if headers[b'connection'].lower() != b'upgrade': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Connection: Upgrade header") + if headers[b'upgrade'].lower() != 
b'websocket': + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Upgrade: WebSocket header") + + if b'sec-websocket-version' not in headers: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Sec-WebSocket-Version header") + # XX FIXME: need to check Sec-Websocket-Version, and respond with a + # 400 if it's not what we expect + + if b'sec-websocket-protocol' in headers: + proposed_subprotocols = _split_comma_header( + headers[b'sec-websocket-protocol']) + else: + proposed_subprotocols = [] + + if b'sec-websocket-key' not in headers: + return ConnectionFailed(CloseReason.PROTOCOL_ERROR, + "Missing Sec-WebSocket-Key header") + + return ConnectionRequested(proposed_subprotocols, event) + + def _extension_accept(self, extensions_header): + accepts = {} + offers = _split_comma_header(extensions_header) + + for offer in offers: + name = offer.split(';', 1)[0].strip() + for extension in self.extensions: + if extension.name == name: + accept = extension.accept(self, offer) + if accept is True: + accepts[extension.name] = True + elif accept is not False and accept is not None: + accepts[extension.name] = accept.encode('ascii') + + if accepts: + extensions = [] + for name, params in accepts.items(): + if params is True: + extensions.append(name.encode('ascii')) + else: + # py34 annoyance: doesn't support bytestring formatting + params = params.decode("ascii") + extensions.append(('%s; %s' % (name, params)) + .encode("ascii")) + return b', '.join(extensions) + + return None + + def accept(self, event, subprotocol=None): + request = event.h11request + request_headers = _normed_header_dict(request.headers) + + nonce = request_headers[b'sec-websocket-key'] + accept_token = self._generate_accept_token(nonce) + + headers = { + b"Upgrade": b'WebSocket', + b"Connection": b'Upgrade', + b"Sec-WebSocket-Accept": accept_token, + } + + if subprotocol is not None: + if subprotocol not in event.proposed_subprotocols: + raise ValueError( + "unexpected subprotocol {!r}".format(subprotocol)) + headers[b'Sec-WebSocket-Protocol'] = subprotocol + + extensions = request_headers.get(b'sec-websocket-extensions', None) + if extensions: + accepts = self._extension_accept(extensions) + if accepts: + headers[b"Sec-WebSocket-Extensions"] = accepts + + response = h11.InformationalResponse(status_code=101, + headers=headers.items()) + self._outgoing += self._upgrade_connection.send(response) + self._proto = FrameProtocol(self.client, self.extensions) + self._state = ConnectionState.OPEN + + def ping(self, payload=None): + """ + Send a PING message to the peer. + + :param payload: an optional payload to send with the message + """ + + payload = bytes(payload or b'') + self._outgoing += self._proto.ping(payload) + + def pong(self, payload=None): + """ + Send a PONG message to the peer. + + This method can be used to send an unsolicted PONG to the peer. + It is not needed otherwise since every received PING causes a + corresponding PONG to be sent automatically. + + :param payload: an optional payload to send with the message + """ + + payload = bytes(payload or b'') + self._outgoing += self._proto.pong(payload) diff --git a/mitmproxy/contrib/wsproto/events.py b/mitmproxy/contrib/wsproto/events.py new file mode 100644 index 000000000..73ce27aac --- /dev/null +++ b/mitmproxy/contrib/wsproto/events.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +""" +wsproto/events +~~~~~~~~~~ + +Events that result from processing data on a WebSocket connection. 
+""" + + +class ConnectionRequested(object): + def __init__(self, proposed_subprotocols, h11request): + self.proposed_subprotocols = proposed_subprotocols + self.h11request = h11request + + def __repr__(self): + path = self.h11request.target + + headers = dict(self.h11request.headers) + host = headers[b'host'] + version = headers[b'sec-websocket-version'] + subprotocol = headers.get(b'sec-websocket-protocol', None) + extensions = [] + + fmt = '<%s host=%s path=%s version=%s subprotocol=%r extensions=%r>' + return fmt % (self.__class__.__name__, host, path, version, + subprotocol, extensions) + + +class ConnectionEstablished(object): + def __init__(self, subprotocol=None, extensions=None): + self.subprotocol = subprotocol + self.extensions = extensions + if self.extensions is None: + self.extensions = [] + + def __repr__(self): + return '' % \ + (self.subprotocol, self.extensions) + + +class ConnectionClosed(object): + def __init__(self, code, reason=None): + self.code = code + self.reason = reason + + def __repr__(self): + return '<%s code=%r reason="%s">' % (self.__class__.__name__, + self.code, self.reason) + + +class ConnectionFailed(ConnectionClosed): + pass + + +class DataReceived(object): + def __init__(self, data, frame_finished, message_finished): + self.data = data + # This has no semantic content, but is provided just in case some + # weird edge case user wants to be able to reconstruct the + # fragmentation pattern of the original stream. You don't want it: + self.frame_finished = frame_finished + # This is the field that you almost certainly want: + self.message_finished = message_finished + + +class TextReceived(DataReceived): + pass + + +class BytesReceived(DataReceived): + pass + + +class PingReceived(object): + def __init__(self, payload): + self.payload = payload + + +class PongReceived(object): + def __init__(self, payload): + self.payload = payload diff --git a/mitmproxy/contrib/wsproto/extensions.py b/mitmproxy/contrib/wsproto/extensions.py new file mode 100644 index 000000000..f7cf4fb61 --- /dev/null +++ b/mitmproxy/contrib/wsproto/extensions.py @@ -0,0 +1,257 @@ +# -*- coding: utf-8 -*- +""" +wsproto/extensions +~~~~~~~~~~~~~~ + +WebSocket extensions. 
+""" + +import zlib + +from .frame_protocol import CloseReason, Opcode, RsvBits + + +class Extension(object): + name = None + + def enabled(self): + return False + + def offer(self, connection): + pass + + def accept(self, connection, offer): + pass + + def finalize(self, connection, offer): + pass + + def frame_inbound_header(self, proto, opcode, rsv, payload_length): + return RsvBits(False, False, False) + + def frame_inbound_payload_data(self, proto, data): + return data + + def frame_inbound_complete(self, proto, fin): + pass + + def frame_outbound(self, proto, opcode, rsv, data, fin): + return (rsv, data) + + +class PerMessageDeflate(Extension): + name = 'permessage-deflate' + + DEFAULT_CLIENT_MAX_WINDOW_BITS = 15 + DEFAULT_SERVER_MAX_WINDOW_BITS = 15 + + def __init__(self, client_no_context_takeover=False, + client_max_window_bits=None, server_no_context_takeover=False, + server_max_window_bits=None): + self.client_no_context_takeover = client_no_context_takeover + if client_max_window_bits is None: + client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS + self.client_max_window_bits = client_max_window_bits + self.server_no_context_takeover = server_no_context_takeover + if server_max_window_bits is None: + server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS + self.server_max_window_bits = server_max_window_bits + + self._compressor = None + self._decompressor = None + # This refers to the current frame + self._inbound_is_compressible = None + # This refers to the ongoing message (which might span multiple + # frames). Only the first frame in a fragmented message is flagged for + # compression, so this carries that bit forward. + self._inbound_compressed = None + + self._enabled = False + + def _compressible_opcode(self, opcode): + return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION) + + def enabled(self): + return self._enabled + + def offer(self, connection): + parameters = [ + 'client_max_window_bits=%d' % self.client_max_window_bits, + 'server_max_window_bits=%d' % self.server_max_window_bits, + ] + + if self.client_no_context_takeover: + parameters.append('client_no_context_takeover') + if self.server_no_context_takeover: + parameters.append('server_no_context_takeover') + + return '; '.join(parameters) + + def finalize(self, connection, offer): + bits = [b.strip() for b in offer.split(';')] + for bit in bits[1:]: + if bit.startswith('client_no_context_takeover'): + self.client_no_context_takeover = True + elif bit.startswith('server_no_context_takeover'): + self.server_no_context_takeover = True + elif bit.startswith('client_max_window_bits'): + self.client_max_window_bits = int(bit.split('=', 1)[1].strip()) + elif bit.startswith('server_max_window_bits'): + self.server_max_window_bits = int(bit.split('=', 1)[1].strip()) + + self._enabled = True + + def _parse_params(self, params): + client_max_window_bits = None + server_max_window_bits = None + + bits = [b.strip() for b in params.split(';')] + for bit in bits[1:]: + if bit.startswith('client_no_context_takeover'): + self.client_no_context_takeover = True + elif bit.startswith('server_no_context_takeover'): + self.server_no_context_takeover = True + elif bit.startswith('client_max_window_bits'): + if '=' in bit: + client_max_window_bits = int(bit.split('=', 1)[1].strip()) + else: + client_max_window_bits = self.client_max_window_bits + elif bit.startswith('server_max_window_bits'): + if '=' in bit: + server_max_window_bits = int(bit.split('=', 1)[1].strip()) + else: + server_max_window_bits = 
self.server_max_window_bits + + return client_max_window_bits, server_max_window_bits + + def accept(self, connection, offer): + client_max_window_bits, server_max_window_bits = \ + self._parse_params(offer) + + self._enabled = True + + parameters = [] + + if self.client_no_context_takeover: + parameters.append('client_no_context_takeover') + if client_max_window_bits is not None: + parameters.append('client_max_window_bits=%d' % + client_max_window_bits) + self.client_max_window_bits = client_max_window_bits + if self.server_no_context_takeover: + parameters.append('server_no_context_takeover') + if server_max_window_bits is not None: + parameters.append('server_max_window_bits=%d' % + server_max_window_bits) + self.server_max_window_bits = server_max_window_bits + + return '; '.join(parameters) + + def frame_inbound_header(self, proto, opcode, rsv, payload_length): + if rsv.rsv1 and opcode.iscontrol(): + return CloseReason.PROTOCOL_ERROR + elif rsv.rsv1 and opcode is Opcode.CONTINUATION: + return CloseReason.PROTOCOL_ERROR + + self._inbound_is_compressible = self._compressible_opcode(opcode) + + if self._inbound_compressed is None: + self._inbound_compressed = rsv.rsv1 + if self._inbound_compressed: + assert self._inbound_is_compressible + if proto.client: + bits = self.server_max_window_bits + else: + bits = self.client_max_window_bits + if self._decompressor is None: + self._decompressor = zlib.decompressobj(-int(bits)) + + return RsvBits(True, False, False) + + def frame_inbound_payload_data(self, proto, data): + if not self._inbound_compressed or not self._inbound_is_compressible: + return data + + try: + return self._decompressor.decompress(bytes(data)) + except zlib.error: + return CloseReason.INVALID_FRAME_PAYLOAD_DATA + + def frame_inbound_complete(self, proto, fin): + if not fin: + return + elif not self._inbound_is_compressible: + return + elif not self._inbound_compressed: + return + + try: + data = self._decompressor.decompress(b'\x00\x00\xff\xff') + data += self._decompressor.flush() + except zlib.error: + return CloseReason.INVALID_FRAME_PAYLOAD_DATA + + if proto.client: + no_context_takeover = self.server_no_context_takeover + else: + no_context_takeover = self.client_no_context_takeover + + if no_context_takeover: + self._decompressor = None + + self._inbound_compressed = None + + return data + + def frame_outbound(self, proto, opcode, rsv, data, fin): + if not self._compressible_opcode(opcode): + return (rsv, data) + + if opcode is not Opcode.CONTINUATION: + rsv = RsvBits(True, *rsv[1:]) + + if self._compressor is None: + assert opcode is not Opcode.CONTINUATION + if proto.client: + bits = self.client_max_window_bits + else: + bits = self.server_max_window_bits + self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, + zlib.DEFLATED, -int(bits)) + + data = self._compressor.compress(bytes(data)) + + if fin: + data += self._compressor.flush(zlib.Z_SYNC_FLUSH) + data = data[:-4] + + if proto.client: + no_context_takeover = self.client_no_context_takeover + else: + no_context_takeover = self.server_no_context_takeover + + if no_context_takeover: + self._compressor = None + + return (rsv, data) + + def __repr__(self): + descr = ['client_max_window_bits=%d' % self.client_max_window_bits] + if self.client_no_context_takeover: + descr.append('client_no_context_takeover') + descr.append('server_max_window_bits=%d' % self.server_max_window_bits) + if self.server_no_context_takeover: + descr.append('server_no_context_takeover') + + descr = '; '.join(descr) + + return '<%s 
%s>' % (self.__class__.__name__, descr) + + +#: SUPPORTED_EXTENSIONS maps all supported extension names to their class. +#: This can be used to iterate all supported extensions of wsproto, instantiate +#: new extensions based on their name, or check if a given extension is +#: supported or not. +SUPPORTED_EXTENSIONS = { + PerMessageDeflate.name: PerMessageDeflate +} diff --git a/mitmproxy/contrib/wsproto/frame_protocol.py b/mitmproxy/contrib/wsproto/frame_protocol.py new file mode 100644 index 000000000..b95dceec2 --- /dev/null +++ b/mitmproxy/contrib/wsproto/frame_protocol.py @@ -0,0 +1,579 @@ +# -*- coding: utf-8 -*- +""" +wsproto/frame_protocol +~~~~~~~~~~~~~~ + +WebSocket frame protocol implementation. +""" + +import os +import itertools +import struct +from codecs import getincrementaldecoder +from collections import namedtuple + +from enum import Enum, IntEnum + +from .compat import unicode, Utf8Validator + +try: + from wsaccel.xormask import XorMaskerSimple +except ImportError: + class XorMaskerSimple: + def __init__(self, masking_key): + self._maskbytes = itertools.cycle(bytearray(masking_key)) + + def process(self, data): + maskbytes = self._maskbytes + return bytearray(b ^ next(maskbytes) for b in bytearray(data)) + + +class XorMaskerNull: + def process(self, data): + return data + + +# RFC6455, Section 5.2 - Base Framing Protocol + +# Payload length constants +PAYLOAD_LENGTH_TWO_BYTE = 126 +PAYLOAD_LENGTH_EIGHT_BYTE = 127 +MAX_PAYLOAD_NORMAL = 125 +MAX_PAYLOAD_TWO_BYTE = 2 ** 16 - 1 +MAX_PAYLOAD_EIGHT_BYTE = 2 ** 64 - 1 +MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE + +# MASK and PAYLOAD LEN are packed into a byte +MASK_MASK = 0x80 +PAYLOAD_LEN_MASK = 0x7f + +# FIN, RSV[123] and OPCODE are packed into a single byte +FIN_MASK = 0x80 +RSV1_MASK = 0x40 +RSV2_MASK = 0x20 +RSV3_MASK = 0x10 +OPCODE_MASK = 0x0f + + +class Opcode(IntEnum): + """ + RFC 6455, Section 5.2 - Base Framing Protocol + """ + CONTINUATION = 0x0 + TEXT = 0x1 + BINARY = 0x2 + CLOSE = 0x8 + PING = 0x9 + PONG = 0xA + + def iscontrol(self): + return bool(self & 0x08) + + +class CloseReason(IntEnum): + """ + RFC 6455, Section 7.4.1 - Defined Status Codes + """ + NORMAL_CLOSURE = 1000 + GOING_AWAY = 1001 + PROTOCOL_ERROR = 1002 + UNSUPPORTED_DATA = 1003 + NO_STATUS_RCVD = 1005 + ABNORMAL_CLOSURE = 1006 + INVALID_FRAME_PAYLOAD_DATA = 1007 + POLICY_VIOLATION = 1008 + MESSAGE_TOO_BIG = 1009 + MANDATORY_EXT = 1010 + INTERNAL_ERROR = 1011 + SERVICE_RESTART = 1012 + TRY_AGAIN_LATER = 1013 + TLS_HANDSHAKE_FAILED = 1015 + + +# RFC 6455, Section 7.4.1 - Defined Status Codes +LOCAL_ONLY_CLOSE_REASONS = ( + CloseReason.NO_STATUS_RCVD, + CloseReason.ABNORMAL_CLOSURE, + CloseReason.TLS_HANDSHAKE_FAILED, +) + + +# RFC 6455, Section 7.4.2 - Status Code Ranges +MIN_CLOSE_REASON = 1000 +MIN_PROTOCOL_CLOSE_REASON = 1000 +MAX_PROTOCOL_CLOSE_REASON = 2999 +MIN_LIBRARY_CLOSE_REASON = 3000 +MAX_LIBRARY_CLOSE_REASON = 3999 +MIN_PRIVATE_CLOSE_REASON = 4000 +MAX_PRIVATE_CLOSE_REASON = 4999 +MAX_CLOSE_REASON = 4999 + + +NULL_MASK = struct.pack("!I", 0) + + +class ParseFailed(Exception): + def __init__(self, msg, code=CloseReason.PROTOCOL_ERROR): + super(ParseFailed, self).__init__(msg) + self.code = code + + +Header = namedtuple("Header", "fin rsv opcode payload_len masking_key".split()) + + +Frame = namedtuple("Frame", + "opcode payload frame_finished message_finished".split()) + + +RsvBits = namedtuple("RsvBits", "rsv1 rsv2 rsv3".split()) + + +def _truncate_utf8(data, nbytes): + if len(data) <= nbytes: + return data + + # Truncate + data = 
data[:nbytes] + # But we might have cut a codepoint in half, in which case we want to + # discard the partial character so the data is at least + # well-formed. This is a little inefficient since it processes the + # whole message twice when in theory we could just peek at the last + # few characters, but since this is only used for close messages (max + # length = 125 bytes) it really doesn't matter. + data = data.decode("utf-8", errors="ignore").encode("utf-8") + return data + + +class Buffer(object): + def __init__(self, initial_bytes=None): + self.buffer = bytearray() + self.bytes_used = 0 + if initial_bytes: + self.feed(initial_bytes) + + def feed(self, new_bytes): + self.buffer += new_bytes + + def consume_at_most(self, nbytes): + if not nbytes: + return bytearray() + + data = self.buffer[self.bytes_used:self.bytes_used + nbytes] + self.bytes_used += len(data) + return data + + def consume_exactly(self, nbytes): + if len(self.buffer) - self.bytes_used < nbytes: + return None + + return self.consume_at_most(nbytes) + + def commit(self): + # In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic + del self.buffer[:self.bytes_used] + self.bytes_used = 0 + + def rollback(self): + self.bytes_used = 0 + + def __len__(self): + return len(self.buffer) + + +class MessageDecoder(object): + def __init__(self): + self.opcode = None + self.validator = None + self.decoder = None + + def process_frame(self, frame): + assert not frame.opcode.iscontrol() + + if self.opcode is None: + if frame.opcode is Opcode.CONTINUATION: + raise ParseFailed("unexpected CONTINUATION") + self.opcode = frame.opcode + elif frame.opcode is not Opcode.CONTINUATION: + raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode) + + if frame.opcode is Opcode.TEXT: + self.validator = Utf8Validator() + self.decoder = getincrementaldecoder("utf-8")() + + finished = frame.frame_finished and frame.message_finished + + if self.decoder is not None: + data = self.decode_payload(frame.payload, finished) + else: + data = frame.payload + + frame = Frame(self.opcode, data, frame.frame_finished, finished) + + if finished: + self.opcode = None + self.decoder = None + + return frame + + def decode_payload(self, data, finished): + if self.validator is not None: + results = self.validator.validate(bytes(data)) + if not results[0] or (finished and not results[1]): + raise ParseFailed(u'encountered invalid UTF-8 while processing' + ' text message at payload octet index %d' % + results[3], + CloseReason.INVALID_FRAME_PAYLOAD_DATA) + + try: + return self.decoder.decode(data, finished) + except UnicodeDecodeError as exc: + raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA) + + +class FrameDecoder(object): + def __init__(self, client, extensions=None): + self.client = client + self.extensions = extensions or [] + + self.buffer = Buffer() + + self.header = None + self.effective_opcode = None + self.masker = None + self.payload_required = 0 + self.payload_consumed = 0 + + def receive_bytes(self, data): + self.buffer.feed(data) + + def process_buffer(self): + if not self.header: + if not self.parse_header(): + return None + + if len(self.buffer) < self.payload_required: + return None + + payload_remaining = self.header.payload_len - self.payload_consumed + payload = self.buffer.consume_at_most(payload_remaining) + if not payload and self.header.payload_len > 0: + return None + self.buffer.commit() + + self.payload_consumed += len(payload) + finished = self.payload_consumed == self.header.payload_len + + payload = 
self.masker.process(payload) + + for extension in self.extensions: + payload = extension.frame_inbound_payload_data(self, payload) + if isinstance(payload, CloseReason): + raise ParseFailed("error in extension", payload) + + if finished: + final = bytearray() + for extension in self.extensions: + result = extension.frame_inbound_complete(self, + self.header.fin) + if isinstance(result, CloseReason): + raise ParseFailed("error in extension", result) + if result is not None: + final += result + payload += final + + frame = Frame(self.effective_opcode, payload, finished, + self.header.fin) + + if finished: + self.header = None + self.effective_opcode = None + self.masker = None + else: + self.effective_opcode = Opcode.CONTINUATION + + return frame + + def parse_header(self): + data = self.buffer.consume_exactly(2) + if data is None: + self.buffer.rollback() + return False + + fin = bool(data[0] & FIN_MASK) + rsv = RsvBits(bool(data[0] & RSV1_MASK), + bool(data[0] & RSV2_MASK), + bool(data[0] & RSV3_MASK)) + opcode = data[0] & OPCODE_MASK + try: + opcode = Opcode(opcode) + except ValueError: + raise ParseFailed("Invalid opcode {:#x}".format(opcode)) + + if opcode.iscontrol() and not fin: + raise ParseFailed("Invalid attempt to fragment control frame") + + has_mask = bool(data[1] & MASK_MASK) + payload_len = data[1] & PAYLOAD_LEN_MASK + payload_len = self.parse_extended_payload_length(opcode, payload_len) + if payload_len is None: + self.buffer.rollback() + return False + + self.extension_processing(opcode, rsv, payload_len) + + if has_mask and self.client: + raise ParseFailed("client received unexpected masked frame") + if not has_mask and not self.client: + raise ParseFailed("server received unexpected unmasked frame") + if has_mask: + masking_key = self.buffer.consume_exactly(4) + if masking_key is None: + self.buffer.rollback() + return False + self.masker = XorMaskerSimple(masking_key) + else: + self.masker = XorMaskerNull() + + self.buffer.commit() + self.header = Header(fin, rsv, opcode, payload_len, None) + self.effective_opcode = self.header.opcode + if self.header.opcode.iscontrol(): + self.payload_required = payload_len + else: + self.payload_required = 0 + self.payload_consumed = 0 + return True + + def parse_extended_payload_length(self, opcode, payload_len): + if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL: + raise ParseFailed("Control frame with payload len > 125") + if payload_len == PAYLOAD_LENGTH_TWO_BYTE: + data = self.buffer.consume_exactly(2) + if data is None: + return None + (payload_len,) = struct.unpack("!H", data) + if payload_len <= MAX_PAYLOAD_NORMAL: + raise ParseFailed( + "Payload length used 2 bytes when 1 would have sufficed") + elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE: + data = self.buffer.consume_exactly(8) + if data is None: + return None + (payload_len,) = struct.unpack("!Q", data) + if payload_len <= MAX_PAYLOAD_TWO_BYTE: + raise ParseFailed( + "Payload length used 8 bytes when 2 would have sufficed") + if payload_len >> 63: + # I'm not sure why this is illegal, but that's what the RFC + # says, so... 
+ raise ParseFailed("8-byte payload length with non-zero MSB") + + return payload_len + + def extension_processing(self, opcode, rsv, payload_len): + rsv_used = [False, False, False] + for extension in self.extensions: + result = extension.frame_inbound_header(self, opcode, rsv, + payload_len) + if isinstance(result, CloseReason): + raise ParseFailed("error in extension", result) + for bit, used in enumerate(result): + if used: + rsv_used[bit] = True + for expected, found in zip(rsv_used, rsv): + if found and not expected: + raise ParseFailed("Reserved bit set unexpectedly") + + +class FrameProtocol(object): + class State(Enum): + HEADER = 1 + PAYLOAD = 2 + FRAME_COMPLETE = 3 + FAILED = 4 + + def __init__(self, client, extensions): + self.client = client + self.extensions = [ext for ext in extensions if ext.enabled()] + + # Global state + self._frame_decoder = FrameDecoder(self.client, self.extensions) + self._message_decoder = MessageDecoder() + self._parse_more = self.parse_more_gen() + + self._outbound_opcode = None + + def _process_close(self, frame): + data = frame.payload + + if not data: + # "If this Close control frame contains no status code, _The + # WebSocket Connection Close Code_ is considered to be 1005" + data = (CloseReason.NO_STATUS_RCVD, "") + elif len(data) == 1: + raise ParseFailed("CLOSE with 1 byte payload") + else: + (code,) = struct.unpack("!H", data[:2]) + if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON: + raise ParseFailed("CLOSE with invalid code") + try: + code = CloseReason(code) + except ValueError: + pass + if code in LOCAL_ONLY_CLOSE_REASONS: + raise ParseFailed( + "remote CLOSE with local-only reason") + if not isinstance(code, CloseReason) and \ + code <= MAX_PROTOCOL_CLOSE_REASON: + raise ParseFailed( + "CLOSE with unknown reserved code") + validator = Utf8Validator() + if validator is not None: + results = validator.validate(bytes(data[2:])) + if not (results[0] and results[1]): + raise ParseFailed(u'encountered invalid UTF-8 while' + ' processing close message at payload' + ' octet index %d' % + results[3], + CloseReason.INVALID_FRAME_PAYLOAD_DATA) + try: + reason = data[2:].decode("utf-8") + except UnicodeDecodeError as exc: + raise ParseFailed( + "Error decoding CLOSE reason: " + str(exc), + CloseReason.INVALID_FRAME_PAYLOAD_DATA) + data = (code, reason) + + return Frame(frame.opcode, data, frame.frame_finished, + frame.message_finished) + + def parse_more_gen(self): + # Consume as much as we can from self._buffer, yielding events, and + # then yield None when we need more data. Or raise ParseFailed. + + # XX FIXME this should probably be refactored so that we never see + # disabled extensions in the first place... 
+ self.extensions = [ext for ext in self.extensions if ext.enabled()] + closed = False + + while not closed: + frame = self._frame_decoder.process_buffer() + + if frame is not None: + if not frame.opcode.iscontrol(): + frame = self._message_decoder.process_frame(frame) + elif frame.opcode == Opcode.CLOSE: + frame = self._process_close(frame) + closed = True + + yield frame + + def receive_bytes(self, data): + self._frame_decoder.receive_bytes(data) + + def received_frames(self): + for event in self._parse_more: + if event is None: + break + else: + yield event + + def close(self, code=None, reason=None): + payload = bytearray() + if code is None and reason is not None: + raise TypeError("cannot specify a reason without a code") + if code in LOCAL_ONLY_CLOSE_REASONS: + code = CloseReason.NORMAL_CLOSURE + if code is not None: + payload += bytearray(struct.pack('!H', code)) + if reason is not None: + payload += _truncate_utf8(reason.encode('utf-8'), + MAX_PAYLOAD_NORMAL - 2) + + return self._serialize_frame(Opcode.CLOSE, payload) + + def ping(self, payload=b''): + return self._serialize_frame(Opcode.PING, payload) + + def pong(self, payload=b''): + return self._serialize_frame(Opcode.PONG, payload) + + def send_data(self, payload=b'', fin=True): + if isinstance(payload, (bytes, bytearray, memoryview)): + opcode = Opcode.BINARY + elif isinstance(payload, unicode): + opcode = Opcode.TEXT + payload = payload.encode('utf-8') + else: + raise ValueError('Must provide bytes or text') + + if self._outbound_opcode is None: + self._outbound_opcode = opcode + elif self._outbound_opcode is not opcode: + raise TypeError('Data type mismatch inside message') + else: + opcode = Opcode.CONTINUATION + + if fin: + self._outbound_opcode = None + + return self._serialize_frame(opcode, payload, fin) + + def _make_fin_rsv_opcode(self, fin, rsv, opcode): + fin = int(fin) << 7 + rsv = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + \ + (int(rsv.rsv3) << 4) + opcode = int(opcode) + + return fin | rsv | opcode + + def _serialize_frame(self, opcode, payload=b'', fin=True): + rsv = RsvBits(False, False, False) + for extension in reversed(self.extensions): + rsv, payload = extension.frame_outbound(self, opcode, rsv, payload, + fin) + + fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode) + + payload_length = len(payload) + quad_payload = False + if payload_length <= MAX_PAYLOAD_NORMAL: + first_payload = payload_length + second_payload = None + elif payload_length <= MAX_PAYLOAD_TWO_BYTE: + first_payload = PAYLOAD_LENGTH_TWO_BYTE + second_payload = payload_length + else: + first_payload = PAYLOAD_LENGTH_EIGHT_BYTE + second_payload = payload_length + quad_payload = True + + if self.client: + first_payload |= 1 << 7 + + header = bytearray([fin_rsv_opcode, first_payload]) + if second_payload is not None: + if opcode.iscontrol(): + raise ValueError("payload too long for control frame") + if quad_payload: + header += bytearray(struct.pack('!Q', second_payload)) + else: + header += bytearray(struct.pack('!H', second_payload)) + + if self.client: + # "The masking key is a 32-bit value chosen at random by the + # client. When preparing a masked frame, the client MUST pick a + # fresh masking key from the set of allowed 32-bit values. The + # masking key needs to be unpredictable; thus, the masking key + # MUST be derived from a strong source of entropy, and the masking + # key for a given frame MUST NOT make it simple for a server/proxy + # to predict the masking key for a subsequent frame. 
The + # unpredictability of the masking key is essential to prevent + # authors of malicious applications from selecting the bytes that + # appear on the wire." + # -- https://tools.ietf.org/html/rfc6455#section-5.3 + masking_key = os.urandom(4) + masker = XorMaskerSimple(masking_key) + return header + masking_key + masker.process(payload) + + return header + payload diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 54d8120de..34dcba066 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -1,10 +1,10 @@ import socket from OpenSSL import SSL -from wsproto import events -from wsproto.connection import ConnectionType, WSConnection -from wsproto.extensions import PerMessageDeflate -from wsproto.frame_protocol import Opcode +from mitmproxy.contrib.wsproto import events +from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection +from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate +from mitmproxy.contrib.wsproto.frame_protocol import Opcode from mitmproxy import exceptions from mitmproxy import flow diff --git a/setup.py b/setup.py index 54c2811d9..ad792881e 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ setup( "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! "click>=6.2, <7", "cryptography>=2.0,<2.2", + 'h11>=0.7.0,<0.8', "h2>=3.0, <4", "hyperframe>=5.0, <6", "kaitaistruct>=0.7, <0.8",
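
Note for reviewers (not part of the patch): the sketch below shows roughly how the vendored API is driven. Bytes read from the network are fed to receive_bytes(), protocol activity is read back from the events() generator, and whatever bytes_to_send() returns is written back to the wire. This mirrors the call pattern used in mitmproxy/proxy/protocol/websocket.py, which is why only the import paths change there. The raw-socket plumbing, the serve_echo_once name, and the echo behaviour are assumptions made purely for illustration.

# Illustrative sketch only -- not part of the vendored code or of this patch.
# Server side of the API, assuming `sock` is a connected socket.socket.
import socket

from mitmproxy.contrib.wsproto import events
from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection
from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate


def serve_echo_once(sock):
    ws = WSConnection(ConnectionType.SERVER,
                      extensions=[PerMessageDeflate()])

    while True:
        data = sock.recv(4096)
        # An empty read means the transport is gone; wsproto maps that to an
        # ABNORMAL_CLOSURE if the connection was still open.
        ws.receive_bytes(data or None)

        for event in ws.events():
            if isinstance(event, events.ConnectionRequested):
                ws.accept(event)  # complete the HTTP upgrade handshake
            elif isinstance(event, (events.TextReceived,
                                    events.BytesReceived)):
                # Echo each message part back, preserving fragmentation.
                # PINGs need no handling: the connection pongs automatically.
                ws.send_data(event.data, final=event.message_finished)
            elif isinstance(event, events.ConnectionClosed):
                data = b''  # flush the queued close reply, then stop below

        outgoing = ws.bytes_to_send()
        if outgoing:
            sock.sendall(outgoing)

        if not data:
            return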