From 4495562f86de5c3f41a0cf1616c08e74f45b20dd Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 1 Jan 2018 11:16:52 +0100 Subject: [PATCH] unvendor wsproto --- mitmproxy/contrib/wsproto/__init__.py | 13 - mitmproxy/contrib/wsproto/compat.py | 20 - mitmproxy/contrib/wsproto/connection.py | 477 ---------------- mitmproxy/contrib/wsproto/events.py | 81 --- mitmproxy/contrib/wsproto/extensions.py | 259 --------- mitmproxy/contrib/wsproto/frame_protocol.py | 581 -------------------- mitmproxy/proxy/protocol/websocket.py | 8 +- mitmproxy/websocket.py | 11 +- setup.py | 1 + 9 files changed, 11 insertions(+), 1440 deletions(-) delete mode 100644 mitmproxy/contrib/wsproto/__init__.py delete mode 100644 mitmproxy/contrib/wsproto/compat.py delete mode 100644 mitmproxy/contrib/wsproto/connection.py delete mode 100644 mitmproxy/contrib/wsproto/events.py delete mode 100644 mitmproxy/contrib/wsproto/extensions.py delete mode 100644 mitmproxy/contrib/wsproto/frame_protocol.py diff --git a/mitmproxy/contrib/wsproto/__init__.py b/mitmproxy/contrib/wsproto/__init__.py deleted file mode 100644 index d0592bc53..000000000 --- a/mitmproxy/contrib/wsproto/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from . import compat -from . import connection -from . import events -from . import extensions -from . import frame_protocol - -__all__ = [ - 'compat', - 'connection', - 'events', - 'extensions', - 'frame_protocol', -] diff --git a/mitmproxy/contrib/wsproto/compat.py b/mitmproxy/contrib/wsproto/compat.py deleted file mode 100644 index 1911f83cf..000000000 --- a/mitmproxy/contrib/wsproto/compat.py +++ /dev/null @@ -1,20 +0,0 @@ -# flake8: noqa - -import sys - - -PY2 = sys.version_info.major == 2 -PY3 = sys.version_info.major == 3 - - -if PY3: - unicode = str - - def Utf8Validator(): - return None -else: - unicode = unicode - try: - from wsaccel.utf8validator import Utf8Validator - except ImportError: - from .utf8validator import Utf8Validator diff --git a/mitmproxy/contrib/wsproto/connection.py b/mitmproxy/contrib/wsproto/connection.py deleted file mode 100644 index f994cd3ab..000000000 --- a/mitmproxy/contrib/wsproto/connection.py +++ /dev/null @@ -1,477 +0,0 @@ -# -*- coding: utf-8 -*- -""" -wsproto/connection -~~~~~~~~~~~~~~ - -An implementation of a WebSocket connection. -""" - -import os -import base64 -import hashlib -from collections import deque - -from enum import Enum - -import h11 - -from .events import ( - ConnectionRequested, ConnectionEstablished, ConnectionClosed, - ConnectionFailed, TextReceived, BytesReceived, PingReceived, PongReceived -) -from .frame_protocol import FrameProtocol, ParseFailed, CloseReason, Opcode - - -# RFC6455, Section 1.3 - Opening Handshake -ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - - -class ConnectionState(Enum): - """ - RFC 6455, Section 4 - Opening Handshake - """ - CONNECTING = 0 - OPEN = 1 - CLOSING = 2 - CLOSED = 3 - - -class ConnectionType(Enum): - CLIENT = 1 - SERVER = 2 - - -CLIENT = ConnectionType.CLIENT -SERVER = ConnectionType.SERVER - - -# Some convenience utilities for working with HTTP headers -def _normed_header_dict(h11_headers): - # This mangles Set-Cookie headers. But it happens that we don't care about - # any of those, so it's OK. For every other HTTP header, if there are - # multiple instances then you're allowed to join them together with - # commas. - name_to_values = {} - for name, value in h11_headers: - name_to_values.setdefault(name, []).append(value) - name_to_normed_value = {} - for name, values in name_to_values.items(): - name_to_normed_value[name] = b", ".join(values) - return name_to_normed_value - - -# We use this for parsing the proposed protocol list, and for parsing the -# proposed and accepted extension lists. For the proposed protocol list it's -# fine, because the ABNF is just 1#token. But for the extension lists, it's -# wrong, because those can contain quoted strings, which can in turn contain -# commas. XX FIXME -def _split_comma_header(value): - return [piece.decode('ascii').strip() for piece in value.split(b',')] - - -class WSConnection(object): - """ - A low-level WebSocket connection object. - - This wraps two other protocol objects, an HTTP/1.1 protocol object used - to do the initial HTTP upgrade handshake and a WebSocket frame protocol - object used to exchange messages and other control frames. - - :param conn_type: Whether this object is on the client- or server-side of - a connection. To initialise as a client pass ``CLIENT`` otherwise - pass ``SERVER``. - :type conn_type: ``ConnectionType`` - - :param host: The hostname to pass to the server when acting as a client. - :type host: ``str`` - - :param resource: The resource (aka path) to pass to the server when acting - as a client. - :type resource: ``str`` - - :param extensions: A list of extensions to use on this connection. - Extensions should be instances of a subclass of - :class:`Extension `. - - :param subprotocols: A list of subprotocols to request when acting as a - client, ordered by preference. This has no impact on the connection - itself. - :type subprotocol: ``list`` of ``str`` - """ - - def __init__(self, conn_type, host=None, resource=None, extensions=None, - subprotocols=None): - self.client = conn_type is ConnectionType.CLIENT - - self.host = host - self.resource = resource - - self.subprotocols = subprotocols or [] - self.extensions = extensions or [] - - self.version = b'13' - - self._state = ConnectionState.CONNECTING - self._close_reason = None - - self._nonce = None - self._outgoing = b'' - self._events = deque() - self._proto = None - - if self.client: - self._upgrade_connection = h11.Connection(h11.CLIENT) - else: - self._upgrade_connection = h11.Connection(h11.SERVER) - - if self.client: - if self.host is None: - raise ValueError( - "Host must not be None for a client-side connection.") - if self.resource is None: - raise ValueError( - "Resource must not be None for a client-side connection.") - self.initiate_connection() - - def initiate_connection(self): - self._generate_nonce() - - headers = { - b"Host": self.host.encode('ascii'), - b"Upgrade": b'WebSocket', - b"Connection": b'Upgrade', - b"Sec-WebSocket-Key": self._nonce, - b"Sec-WebSocket-Version": self.version, - } - - if self.subprotocols: - headers[b"Sec-WebSocket-Protocol"] = ", ".join(self.subprotocols) - - if self.extensions: - offers = {e.name: e.offer(self) for e in self.extensions} - extensions = [] - for name, params in offers.items(): - if params is True: - extensions.append(name.encode('ascii')) - elif params: - # py34 annoyance: doesn't support bytestring formatting - extensions.append(('%s; %s' % (name, params)) - .encode("ascii")) - if extensions: - headers[b'Sec-WebSocket-Extensions'] = b', '.join(extensions) - - upgrade = h11.Request(method=b'GET', target=self.resource, - headers=headers.items()) - self._outgoing += self._upgrade_connection.send(upgrade) - - def send_data(self, payload, final=True): - """ - Send a message or part of a message to the remote peer. - - If ``final`` is ``False`` it indicates that this is part of a longer - message. If ``final`` is ``True`` it indicates that this is either a - self-contained message or the last part of a longer message. - - If ``payload`` is of type ``bytes`` then the message is flagged as - being binary If it is of type ``str`` encoded as UTF-8 and sent as - text. - - :param payload: The message body to send. - :type payload: ``bytes`` or ``str`` - - :param final: Whether there are more parts to this message to be sent. - :type final: ``bool`` - """ - - self._outgoing += self._proto.send_data(payload, final) - - def close(self, code=CloseReason.NORMAL_CLOSURE, reason=None): - self._outgoing += self._proto.close(code, reason) - self._state = ConnectionState.CLOSING - - @property - def closed(self): - return self._state is ConnectionState.CLOSED - - def bytes_to_send(self, amount=None): - """ - Return any data that is to be sent to the remote peer. - - :param amount: (optional) The maximum number of bytes to be provided. - If ``None`` or not provided it will return all available bytes. - :type amount: ``int`` - """ - - if amount is None: - data = self._outgoing - self._outgoing = b'' - else: - data = self._outgoing[:amount] - self._outgoing = self._outgoing[amount:] - - return data - - def receive_bytes(self, data): - """ - Pass some received bytes to the connection for processing. - - :param data: The data received from the remote peer. - :type data: ``bytes`` - """ - - if data is None and self._state is ConnectionState.OPEN: - # "If _The WebSocket Connection is Closed_ and no Close control - # frame was received by the endpoint (such as could occur if the - # underlying transport connection is lost), _The WebSocket - # Connection Close Code_ is considered to be 1006." - self._events.append(ConnectionClosed(CloseReason.ABNORMAL_CLOSURE)) - self._state = ConnectionState.CLOSED - return - elif data is None: - self._state = ConnectionState.CLOSED - return - - if self._state is ConnectionState.CONNECTING: - event, data = self._process_upgrade(data) - if event is not None: - self._events.append(event) - - if self._state is ConnectionState.OPEN: - self._proto.receive_bytes(data) - - def _process_upgrade(self, data): - self._upgrade_connection.receive_data(data) - while True: - try: - event = self._upgrade_connection.next_event() - except h11.RemoteProtocolError: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Bad HTTP message"), b'' - if event is h11.NEED_DATA: - break - elif self.client and isinstance(event, (h11.InformationalResponse, - h11.Response)): - data = self._upgrade_connection.trailing_data[0] - return self._establish_client_connection(event), data - elif not self.client and isinstance(event, h11.Request): - return self._process_connection_request(event), None - else: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Bad HTTP message"), b'' - - self._incoming = b'' - return None, None - - def events(self): - """ - Return a generator that provides any events that have been generated - by protocol activity. - - :returns: generator - """ - - while self._events: - yield self._events.popleft() - - if self._proto is None: - return - - try: - for frame in self._proto.received_frames(): - if frame.opcode is Opcode.PING: - assert frame.frame_finished and frame.message_finished - self._outgoing += self._proto.pong(frame.payload) - yield PingReceived(frame.payload) - - elif frame.opcode is Opcode.PONG: - assert frame.frame_finished and frame.message_finished - yield PongReceived(frame.payload) - - elif frame.opcode is Opcode.CLOSE: - code, reason = frame.payload - self.close(code, reason) - yield ConnectionClosed(code, reason) - - elif frame.opcode is Opcode.TEXT: - yield TextReceived(frame.payload, - frame.frame_finished, - frame.message_finished) - - elif frame.opcode is Opcode.BINARY: - yield BytesReceived(frame.payload, - frame.frame_finished, - frame.message_finished) - except ParseFailed as exc: - # XX FIXME: apparently autobahn intentionally deviates from the - # spec in that on protocol errors it just closes the connection - # rather than trying to send a CLOSE frame. Investigate whether we - # should do the same. - self.close(code=exc.code, reason=str(exc)) - yield ConnectionClosed(exc.code, reason=str(exc)) - - def _generate_nonce(self): - # os.urandom may be overkill for this use case, but I don't think this - # is a bottleneck, and better safe than sorry... - self._nonce = base64.b64encode(os.urandom(16)) - - def _generate_accept_token(self, token): - accept_token = token + ACCEPT_GUID - accept_token = hashlib.sha1(accept_token).digest() - return base64.b64encode(accept_token) - - def _establish_client_connection(self, event): - if event.status_code != 101: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Bad status code from server") - headers = _normed_header_dict(event.headers) - if headers[b'connection'].lower() != b'upgrade': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Connection: Upgrade header") - if headers[b'upgrade'].lower() != b'websocket': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Upgrade: WebSocket header") - - accept_token = self._generate_accept_token(self._nonce) - if headers[b'sec-websocket-accept'] != accept_token: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Bad accept token") - - subprotocol = headers.get(b'sec-websocket-protocol', None) - if subprotocol is not None: - subprotocol = subprotocol.decode('ascii') - if subprotocol not in self.subprotocols: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "unrecognized subprotocol {!r}" - .format(subprotocol)) - - extensions = headers.get(b'sec-websocket-extensions', None) - if extensions: - accepts = _split_comma_header(extensions) - - for accept in accepts: - name = accept.split(';', 1)[0].strip() - for extension in self.extensions: - if extension.name == name: - extension.finalize(self, accept) - break - else: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "unrecognized extension {!r}" - .format(name)) - - self._proto = FrameProtocol(self.client, self.extensions) - self._state = ConnectionState.OPEN - return ConnectionEstablished(subprotocol, extensions) - - def _process_connection_request(self, event): - if event.method != b'GET': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Request method must be GET") - headers = _normed_header_dict(event.headers) - if headers[b'connection'].lower() != b'upgrade': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Connection: Upgrade header") - if headers[b'upgrade'].lower() != b'websocket': - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Upgrade: WebSocket header") - - if b'sec-websocket-version' not in headers: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Sec-WebSocket-Version header") - # XX FIXME: need to check Sec-Websocket-Version, and respond with a - # 400 if it's not what we expect - - if b'sec-websocket-protocol' in headers: - proposed_subprotocols = _split_comma_header( - headers[b'sec-websocket-protocol']) - else: - proposed_subprotocols = [] - - if b'sec-websocket-key' not in headers: - return ConnectionFailed(CloseReason.PROTOCOL_ERROR, - "Missing Sec-WebSocket-Key header") - - return ConnectionRequested(proposed_subprotocols, event) - - def _extension_accept(self, extensions_header): - accepts = {} - offers = _split_comma_header(extensions_header) - - for offer in offers: - name = offer.split(';', 1)[0].strip() - for extension in self.extensions: - if extension.name == name: - accept = extension.accept(self, offer) - if accept is True: - accepts[extension.name] = True - elif accept is not False and accept is not None: - accepts[extension.name] = accept.encode('ascii') - - if accepts: - extensions = [] - for name, params in accepts.items(): - if params is True: - extensions.append(name.encode('ascii')) - else: - # py34 annoyance: doesn't support bytestring formatting - params = params.decode("ascii") - extensions.append(('%s; %s' % (name, params)) - .encode("ascii")) - return b', '.join(extensions) - - return None - - def accept(self, event, subprotocol=None): - request = event.h11request - request_headers = _normed_header_dict(request.headers) - - nonce = request_headers[b'sec-websocket-key'] - accept_token = self._generate_accept_token(nonce) - - headers = { - b"Upgrade": b'WebSocket', - b"Connection": b'Upgrade', - b"Sec-WebSocket-Accept": accept_token, - } - - if subprotocol is not None: - if subprotocol not in event.proposed_subprotocols: - raise ValueError( - "unexpected subprotocol {!r}".format(subprotocol)) - headers[b'Sec-WebSocket-Protocol'] = subprotocol - - extensions = request_headers.get(b'sec-websocket-extensions', None) - if extensions: - accepts = self._extension_accept(extensions) - if accepts: - headers[b"Sec-WebSocket-Extensions"] = accepts - - response = h11.InformationalResponse(status_code=101, - headers=headers.items()) - self._outgoing += self._upgrade_connection.send(response) - self._proto = FrameProtocol(self.client, self.extensions) - self._state = ConnectionState.OPEN - - def ping(self, payload=None): - """ - Send a PING message to the peer. - - :param payload: an optional payload to send with the message - """ - - payload = bytes(payload or b'') - self._outgoing += self._proto.ping(payload) - - def pong(self, payload=None): - """ - Send a PONG message to the peer. - - This method can be used to send an unsolicted PONG to the peer. - It is not needed otherwise since every received PING causes a - corresponding PONG to be sent automatically. - - :param payload: an optional payload to send with the message - """ - - payload = bytes(payload or b'') - self._outgoing += self._proto.pong(payload) diff --git a/mitmproxy/contrib/wsproto/events.py b/mitmproxy/contrib/wsproto/events.py deleted file mode 100644 index 73ce27aac..000000000 --- a/mitmproxy/contrib/wsproto/events.py +++ /dev/null @@ -1,81 +0,0 @@ -# -*- coding: utf-8 -*- -""" -wsproto/events -~~~~~~~~~~ - -Events that result from processing data on a WebSocket connection. -""" - - -class ConnectionRequested(object): - def __init__(self, proposed_subprotocols, h11request): - self.proposed_subprotocols = proposed_subprotocols - self.h11request = h11request - - def __repr__(self): - path = self.h11request.target - - headers = dict(self.h11request.headers) - host = headers[b'host'] - version = headers[b'sec-websocket-version'] - subprotocol = headers.get(b'sec-websocket-protocol', None) - extensions = [] - - fmt = '<%s host=%s path=%s version=%s subprotocol=%r extensions=%r>' - return fmt % (self.__class__.__name__, host, path, version, - subprotocol, extensions) - - -class ConnectionEstablished(object): - def __init__(self, subprotocol=None, extensions=None): - self.subprotocol = subprotocol - self.extensions = extensions - if self.extensions is None: - self.extensions = [] - - def __repr__(self): - return '' % \ - (self.subprotocol, self.extensions) - - -class ConnectionClosed(object): - def __init__(self, code, reason=None): - self.code = code - self.reason = reason - - def __repr__(self): - return '<%s code=%r reason="%s">' % (self.__class__.__name__, - self.code, self.reason) - - -class ConnectionFailed(ConnectionClosed): - pass - - -class DataReceived(object): - def __init__(self, data, frame_finished, message_finished): - self.data = data - # This has no semantic content, but is provided just in case some - # weird edge case user wants to be able to reconstruct the - # fragmentation pattern of the original stream. You don't want it: - self.frame_finished = frame_finished - # This is the field that you almost certainly want: - self.message_finished = message_finished - - -class TextReceived(DataReceived): - pass - - -class BytesReceived(DataReceived): - pass - - -class PingReceived(object): - def __init__(self, payload): - self.payload = payload - - -class PongReceived(object): - def __init__(self, payload): - self.payload = payload diff --git a/mitmproxy/contrib/wsproto/extensions.py b/mitmproxy/contrib/wsproto/extensions.py deleted file mode 100644 index 0e0d20184..000000000 --- a/mitmproxy/contrib/wsproto/extensions.py +++ /dev/null @@ -1,259 +0,0 @@ -# type: ignore - -# -*- coding: utf-8 -*- -""" -wsproto/extensions -~~~~~~~~~~~~~~ - -WebSocket extensions. -""" - -import zlib - -from .frame_protocol import CloseReason, Opcode, RsvBits - - -class Extension(object): - name = None - - def enabled(self): - return False - - def offer(self, connection): - pass - - def accept(self, connection, offer): - pass - - def finalize(self, connection, offer): - pass - - def frame_inbound_header(self, proto, opcode, rsv, payload_length): - return RsvBits(False, False, False) - - def frame_inbound_payload_data(self, proto, data): - return data - - def frame_inbound_complete(self, proto, fin): - pass - - def frame_outbound(self, proto, opcode, rsv, data, fin): - return (rsv, data) - - -class PerMessageDeflate(Extension): - name = 'permessage-deflate' - - DEFAULT_CLIENT_MAX_WINDOW_BITS = 15 - DEFAULT_SERVER_MAX_WINDOW_BITS = 15 - - def __init__(self, client_no_context_takeover=False, - client_max_window_bits=None, server_no_context_takeover=False, - server_max_window_bits=None): - self.client_no_context_takeover = client_no_context_takeover - if client_max_window_bits is None: - client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS - self.client_max_window_bits = client_max_window_bits - self.server_no_context_takeover = server_no_context_takeover - if server_max_window_bits is None: - server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS - self.server_max_window_bits = server_max_window_bits - - self._compressor = None - self._decompressor = None - # This refers to the current frame - self._inbound_is_compressible = None - # This refers to the ongoing message (which might span multiple - # frames). Only the first frame in a fragmented message is flagged for - # compression, so this carries that bit forward. - self._inbound_compressed = None - - self._enabled = False - - def _compressible_opcode(self, opcode): - return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION) - - def enabled(self): - return self._enabled - - def offer(self, connection): - parameters = [ - 'client_max_window_bits=%d' % self.client_max_window_bits, - 'server_max_window_bits=%d' % self.server_max_window_bits, - ] - - if self.client_no_context_takeover: - parameters.append('client_no_context_takeover') - if self.server_no_context_takeover: - parameters.append('server_no_context_takeover') - - return '; '.join(parameters) - - def finalize(self, connection, offer): - bits = [b.strip() for b in offer.split(';')] - for bit in bits[1:]: - if bit.startswith('client_no_context_takeover'): - self.client_no_context_takeover = True - elif bit.startswith('server_no_context_takeover'): - self.server_no_context_takeover = True - elif bit.startswith('client_max_window_bits'): - self.client_max_window_bits = int(bit.split('=', 1)[1].strip()) - elif bit.startswith('server_max_window_bits'): - self.server_max_window_bits = int(bit.split('=', 1)[1].strip()) - - self._enabled = True - - def _parse_params(self, params): - client_max_window_bits = None - server_max_window_bits = None - - bits = [b.strip() for b in params.split(';')] - for bit in bits[1:]: - if bit.startswith('client_no_context_takeover'): - self.client_no_context_takeover = True - elif bit.startswith('server_no_context_takeover'): - self.server_no_context_takeover = True - elif bit.startswith('client_max_window_bits'): - if '=' in bit: - client_max_window_bits = int(bit.split('=', 1)[1].strip()) - else: - client_max_window_bits = self.client_max_window_bits - elif bit.startswith('server_max_window_bits'): - if '=' in bit: - server_max_window_bits = int(bit.split('=', 1)[1].strip()) - else: - server_max_window_bits = self.server_max_window_bits - - return client_max_window_bits, server_max_window_bits - - def accept(self, connection, offer): - client_max_window_bits, server_max_window_bits = \ - self._parse_params(offer) - - self._enabled = True - - parameters = [] - - if self.client_no_context_takeover: - parameters.append('client_no_context_takeover') - if client_max_window_bits is not None: - parameters.append('client_max_window_bits=%d' % - client_max_window_bits) - self.client_max_window_bits = client_max_window_bits - if self.server_no_context_takeover: - parameters.append('server_no_context_takeover') - if server_max_window_bits is not None: - parameters.append('server_max_window_bits=%d' % - server_max_window_bits) - self.server_max_window_bits = server_max_window_bits - - return '; '.join(parameters) - - def frame_inbound_header(self, proto, opcode, rsv, payload_length): - if rsv.rsv1 and opcode.iscontrol(): - return CloseReason.PROTOCOL_ERROR - elif rsv.rsv1 and opcode is Opcode.CONTINUATION: - return CloseReason.PROTOCOL_ERROR - - self._inbound_is_compressible = self._compressible_opcode(opcode) - - if self._inbound_compressed is None: - self._inbound_compressed = rsv.rsv1 - if self._inbound_compressed: - assert self._inbound_is_compressible - if proto.client: - bits = self.server_max_window_bits - else: - bits = self.client_max_window_bits - if self._decompressor is None: - self._decompressor = zlib.decompressobj(-int(bits)) - - return RsvBits(True, False, False) - - def frame_inbound_payload_data(self, proto, data): - if not self._inbound_compressed or not self._inbound_is_compressible: - return data - - try: - return self._decompressor.decompress(bytes(data)) - except zlib.error: - return CloseReason.INVALID_FRAME_PAYLOAD_DATA - - def frame_inbound_complete(self, proto, fin): - if not fin: - return - elif not self._inbound_is_compressible: - return - elif not self._inbound_compressed: - return - - try: - data = self._decompressor.decompress(b'\x00\x00\xff\xff') - data += self._decompressor.flush() - except zlib.error: - return CloseReason.INVALID_FRAME_PAYLOAD_DATA - - if proto.client: - no_context_takeover = self.server_no_context_takeover - else: - no_context_takeover = self.client_no_context_takeover - - if no_context_takeover: - self._decompressor = None - - self._inbound_compressed = None - - return data - - def frame_outbound(self, proto, opcode, rsv, data, fin): - if not self._compressible_opcode(opcode): - return (rsv, data) - - if opcode is not Opcode.CONTINUATION: - rsv = RsvBits(True, *rsv[1:]) - - if self._compressor is None: - assert opcode is not Opcode.CONTINUATION - if proto.client: - bits = self.client_max_window_bits - else: - bits = self.server_max_window_bits - self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, - zlib.DEFLATED, -int(bits)) - - data = self._compressor.compress(bytes(data)) - - if fin: - data += self._compressor.flush(zlib.Z_SYNC_FLUSH) - data = data[:-4] - - if proto.client: - no_context_takeover = self.client_no_context_takeover - else: - no_context_takeover = self.server_no_context_takeover - - if no_context_takeover: - self._compressor = None - - return (rsv, data) - - def __repr__(self): - descr = ['client_max_window_bits=%d' % self.client_max_window_bits] - if self.client_no_context_takeover: - descr.append('client_no_context_takeover') - descr.append('server_max_window_bits=%d' % self.server_max_window_bits) - if self.server_no_context_takeover: - descr.append('server_no_context_takeover') - - descr = '; '.join(descr) - - return '<%s %s>' % (self.__class__.__name__, descr) - - -#: SUPPORTED_EXTENSIONS maps all supported extension names to their class. -#: This can be used to iterate all supported extensions of wsproto, instantiate -#: new extensions based on their name, or check if a given extension is -#: supported or not. -SUPPORTED_EXTENSIONS = { - PerMessageDeflate.name: PerMessageDeflate -} diff --git a/mitmproxy/contrib/wsproto/frame_protocol.py b/mitmproxy/contrib/wsproto/frame_protocol.py deleted file mode 100644 index 30f146c6d..000000000 --- a/mitmproxy/contrib/wsproto/frame_protocol.py +++ /dev/null @@ -1,581 +0,0 @@ -# type: ignore - -# -*- coding: utf-8 -*- -""" -wsproto/frame_protocol -~~~~~~~~~~~~~~ - -WebSocket frame protocol implementation. -""" - -import os -import itertools -import struct -from codecs import getincrementaldecoder -from collections import namedtuple - -from enum import Enum, IntEnum - -from .compat import unicode, Utf8Validator - -try: - from wsaccel.xormask import XorMaskerSimple -except ImportError: - class XorMaskerSimple: - def __init__(self, masking_key): - self._maskbytes = itertools.cycle(bytearray(masking_key)) - - def process(self, data): - maskbytes = self._maskbytes - return bytearray(b ^ next(maskbytes) for b in bytearray(data)) - - -class XorMaskerNull: - def process(self, data): - return data - - -# RFC6455, Section 5.2 - Base Framing Protocol - -# Payload length constants -PAYLOAD_LENGTH_TWO_BYTE = 126 -PAYLOAD_LENGTH_EIGHT_BYTE = 127 -MAX_PAYLOAD_NORMAL = 125 -MAX_PAYLOAD_TWO_BYTE = 2 ** 16 - 1 -MAX_PAYLOAD_EIGHT_BYTE = 2 ** 64 - 1 -MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE - -# MASK and PAYLOAD LEN are packed into a byte -MASK_MASK = 0x80 -PAYLOAD_LEN_MASK = 0x7f - -# FIN, RSV[123] and OPCODE are packed into a single byte -FIN_MASK = 0x80 -RSV1_MASK = 0x40 -RSV2_MASK = 0x20 -RSV3_MASK = 0x10 -OPCODE_MASK = 0x0f - - -class Opcode(IntEnum): - """ - RFC 6455, Section 5.2 - Base Framing Protocol - """ - CONTINUATION = 0x0 - TEXT = 0x1 - BINARY = 0x2 - CLOSE = 0x8 - PING = 0x9 - PONG = 0xA - - def iscontrol(self): - return bool(self & 0x08) - - -class CloseReason(IntEnum): - """ - RFC 6455, Section 7.4.1 - Defined Status Codes - """ - NORMAL_CLOSURE = 1000 - GOING_AWAY = 1001 - PROTOCOL_ERROR = 1002 - UNSUPPORTED_DATA = 1003 - NO_STATUS_RCVD = 1005 - ABNORMAL_CLOSURE = 1006 - INVALID_FRAME_PAYLOAD_DATA = 1007 - POLICY_VIOLATION = 1008 - MESSAGE_TOO_BIG = 1009 - MANDATORY_EXT = 1010 - INTERNAL_ERROR = 1011 - SERVICE_RESTART = 1012 - TRY_AGAIN_LATER = 1013 - TLS_HANDSHAKE_FAILED = 1015 - - -# RFC 6455, Section 7.4.1 - Defined Status Codes -LOCAL_ONLY_CLOSE_REASONS = ( - CloseReason.NO_STATUS_RCVD, - CloseReason.ABNORMAL_CLOSURE, - CloseReason.TLS_HANDSHAKE_FAILED, -) - - -# RFC 6455, Section 7.4.2 - Status Code Ranges -MIN_CLOSE_REASON = 1000 -MIN_PROTOCOL_CLOSE_REASON = 1000 -MAX_PROTOCOL_CLOSE_REASON = 2999 -MIN_LIBRARY_CLOSE_REASON = 3000 -MAX_LIBRARY_CLOSE_REASON = 3999 -MIN_PRIVATE_CLOSE_REASON = 4000 -MAX_PRIVATE_CLOSE_REASON = 4999 -MAX_CLOSE_REASON = 4999 - - -NULL_MASK = struct.pack("!I", 0) - - -class ParseFailed(Exception): - def __init__(self, msg, code=CloseReason.PROTOCOL_ERROR): - super(ParseFailed, self).__init__(msg) - self.code = code - - -Header = namedtuple("Header", "fin rsv opcode payload_len masking_key".split()) - - -Frame = namedtuple("Frame", - "opcode payload frame_finished message_finished".split()) - - -RsvBits = namedtuple("RsvBits", "rsv1 rsv2 rsv3".split()) - - -def _truncate_utf8(data, nbytes): - if len(data) <= nbytes: - return data - - # Truncate - data = data[:nbytes] - # But we might have cut a codepoint in half, in which case we want to - # discard the partial character so the data is at least - # well-formed. This is a little inefficient since it processes the - # whole message twice when in theory we could just peek at the last - # few characters, but since this is only used for close messages (max - # length = 125 bytes) it really doesn't matter. - data = data.decode("utf-8", errors="ignore").encode("utf-8") - return data - - -class Buffer(object): - def __init__(self, initial_bytes=None): - self.buffer = bytearray() - self.bytes_used = 0 - if initial_bytes: - self.feed(initial_bytes) - - def feed(self, new_bytes): - self.buffer += new_bytes - - def consume_at_most(self, nbytes): - if not nbytes: - return bytearray() - - data = self.buffer[self.bytes_used:self.bytes_used + nbytes] - self.bytes_used += len(data) - return data - - def consume_exactly(self, nbytes): - if len(self.buffer) - self.bytes_used < nbytes: - return None - - return self.consume_at_most(nbytes) - - def commit(self): - # In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic - del self.buffer[:self.bytes_used] - self.bytes_used = 0 - - def rollback(self): - self.bytes_used = 0 - - def __len__(self): - return len(self.buffer) - - -class MessageDecoder(object): - def __init__(self): - self.opcode = None - self.validator = None - self.decoder = None - - def process_frame(self, frame): - assert not frame.opcode.iscontrol() - - if self.opcode is None: - if frame.opcode is Opcode.CONTINUATION: - raise ParseFailed("unexpected CONTINUATION") - self.opcode = frame.opcode - elif frame.opcode is not Opcode.CONTINUATION: - raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode) - - if frame.opcode is Opcode.TEXT: - self.validator = Utf8Validator() - self.decoder = getincrementaldecoder("utf-8")() - - finished = frame.frame_finished and frame.message_finished - - if self.decoder is not None: - data = self.decode_payload(frame.payload, finished) - else: - data = frame.payload - - frame = Frame(self.opcode, data, frame.frame_finished, finished) - - if finished: - self.opcode = None - self.decoder = None - - return frame - - def decode_payload(self, data, finished): - if self.validator is not None: - results = self.validator.validate(bytes(data)) - if not results[0] or (finished and not results[1]): - raise ParseFailed(u'encountered invalid UTF-8 while processing' - ' text message at payload octet index %d' % - results[3], - CloseReason.INVALID_FRAME_PAYLOAD_DATA) - - try: - return self.decoder.decode(data, finished) - except UnicodeDecodeError as exc: - raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA) - - -class FrameDecoder(object): - def __init__(self, client, extensions=None): - self.client = client - self.extensions = extensions or [] - - self.buffer = Buffer() - - self.header = None - self.effective_opcode = None - self.masker = None - self.payload_required = 0 - self.payload_consumed = 0 - - def receive_bytes(self, data): - self.buffer.feed(data) - - def process_buffer(self): - if not self.header: - if not self.parse_header(): - return None - - if len(self.buffer) < self.payload_required: - return None - - payload_remaining = self.header.payload_len - self.payload_consumed - payload = self.buffer.consume_at_most(payload_remaining) - if not payload and self.header.payload_len > 0: - return None - self.buffer.commit() - - self.payload_consumed += len(payload) - finished = self.payload_consumed == self.header.payload_len - - payload = self.masker.process(payload) - - for extension in self.extensions: - payload = extension.frame_inbound_payload_data(self, payload) - if isinstance(payload, CloseReason): - raise ParseFailed("error in extension", payload) - - if finished: - final = bytearray() - for extension in self.extensions: - result = extension.frame_inbound_complete(self, - self.header.fin) - if isinstance(result, CloseReason): - raise ParseFailed("error in extension", result) - if result is not None: - final += result - payload += final - - frame = Frame(self.effective_opcode, payload, finished, - self.header.fin) - - if finished: - self.header = None - self.effective_opcode = None - self.masker = None - else: - self.effective_opcode = Opcode.CONTINUATION - - return frame - - def parse_header(self): - data = self.buffer.consume_exactly(2) - if data is None: - self.buffer.rollback() - return False - - fin = bool(data[0] & FIN_MASK) - rsv = RsvBits(bool(data[0] & RSV1_MASK), - bool(data[0] & RSV2_MASK), - bool(data[0] & RSV3_MASK)) - opcode = data[0] & OPCODE_MASK - try: - opcode = Opcode(opcode) - except ValueError: - raise ParseFailed("Invalid opcode {:#x}".format(opcode)) - - if opcode.iscontrol() and not fin: - raise ParseFailed("Invalid attempt to fragment control frame") - - has_mask = bool(data[1] & MASK_MASK) - payload_len = data[1] & PAYLOAD_LEN_MASK - payload_len = self.parse_extended_payload_length(opcode, payload_len) - if payload_len is None: - self.buffer.rollback() - return False - - self.extension_processing(opcode, rsv, payload_len) - - if has_mask and self.client: - raise ParseFailed("client received unexpected masked frame") - if not has_mask and not self.client: - raise ParseFailed("server received unexpected unmasked frame") - if has_mask: - masking_key = self.buffer.consume_exactly(4) - if masking_key is None: - self.buffer.rollback() - return False - self.masker = XorMaskerSimple(masking_key) - else: - self.masker = XorMaskerNull() - - self.buffer.commit() - self.header = Header(fin, rsv, opcode, payload_len, None) - self.effective_opcode = self.header.opcode - if self.header.opcode.iscontrol(): - self.payload_required = payload_len - else: - self.payload_required = 0 - self.payload_consumed = 0 - return True - - def parse_extended_payload_length(self, opcode, payload_len): - if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL: - raise ParseFailed("Control frame with payload len > 125") - if payload_len == PAYLOAD_LENGTH_TWO_BYTE: - data = self.buffer.consume_exactly(2) - if data is None: - return None - (payload_len,) = struct.unpack("!H", data) - if payload_len <= MAX_PAYLOAD_NORMAL: - raise ParseFailed( - "Payload length used 2 bytes when 1 would have sufficed") - elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE: - data = self.buffer.consume_exactly(8) - if data is None: - return None - (payload_len,) = struct.unpack("!Q", data) - if payload_len <= MAX_PAYLOAD_TWO_BYTE: - raise ParseFailed( - "Payload length used 8 bytes when 2 would have sufficed") - if payload_len >> 63: - # I'm not sure why this is illegal, but that's what the RFC - # says, so... - raise ParseFailed("8-byte payload length with non-zero MSB") - - return payload_len - - def extension_processing(self, opcode, rsv, payload_len): - rsv_used = [False, False, False] - for extension in self.extensions: - result = extension.frame_inbound_header(self, opcode, rsv, - payload_len) - if isinstance(result, CloseReason): - raise ParseFailed("error in extension", result) - for bit, used in enumerate(result): - if used: - rsv_used[bit] = True - for expected, found in zip(rsv_used, rsv): - if found and not expected: - raise ParseFailed("Reserved bit set unexpectedly") - - -class FrameProtocol(object): - class State(Enum): - HEADER = 1 - PAYLOAD = 2 - FRAME_COMPLETE = 3 - FAILED = 4 - - def __init__(self, client, extensions): - self.client = client - self.extensions = [ext for ext in extensions if ext.enabled()] - - # Global state - self._frame_decoder = FrameDecoder(self.client, self.extensions) - self._message_decoder = MessageDecoder() - self._parse_more = self.parse_more_gen() - - self._outbound_opcode = None - - def _process_close(self, frame): - data = frame.payload - - if not data: - # "If this Close control frame contains no status code, _The - # WebSocket Connection Close Code_ is considered to be 1005" - data = (CloseReason.NO_STATUS_RCVD, "") - elif len(data) == 1: - raise ParseFailed("CLOSE with 1 byte payload") - else: - (code,) = struct.unpack("!H", data[:2]) - if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON: - raise ParseFailed("CLOSE with invalid code") - try: - code = CloseReason(code) - except ValueError: - pass - if code in LOCAL_ONLY_CLOSE_REASONS: - raise ParseFailed( - "remote CLOSE with local-only reason") - if not isinstance(code, CloseReason) and \ - code <= MAX_PROTOCOL_CLOSE_REASON: - raise ParseFailed( - "CLOSE with unknown reserved code") - validator = Utf8Validator() - if validator is not None: - results = validator.validate(bytes(data[2:])) - if not (results[0] and results[1]): - raise ParseFailed(u'encountered invalid UTF-8 while' - ' processing close message at payload' - ' octet index %d' % - results[3], - CloseReason.INVALID_FRAME_PAYLOAD_DATA) - try: - reason = data[2:].decode("utf-8") - except UnicodeDecodeError as exc: - raise ParseFailed( - "Error decoding CLOSE reason: " + str(exc), - CloseReason.INVALID_FRAME_PAYLOAD_DATA) - data = (code, reason) - - return Frame(frame.opcode, data, frame.frame_finished, - frame.message_finished) - - def parse_more_gen(self): - # Consume as much as we can from self._buffer, yielding events, and - # then yield None when we need more data. Or raise ParseFailed. - - # XX FIXME this should probably be refactored so that we never see - # disabled extensions in the first place... - self.extensions = [ext for ext in self.extensions if ext.enabled()] - closed = False - - while not closed: - frame = self._frame_decoder.process_buffer() - - if frame is not None: - if not frame.opcode.iscontrol(): - frame = self._message_decoder.process_frame(frame) - elif frame.opcode == Opcode.CLOSE: - frame = self._process_close(frame) - closed = True - - yield frame - - def receive_bytes(self, data): - self._frame_decoder.receive_bytes(data) - - def received_frames(self): - for event in self._parse_more: - if event is None: - break - else: - yield event - - def close(self, code=None, reason=None): - payload = bytearray() - if code is None and reason is not None: - raise TypeError("cannot specify a reason without a code") - if code in LOCAL_ONLY_CLOSE_REASONS: - code = CloseReason.NORMAL_CLOSURE - if code is not None: - payload += bytearray(struct.pack('!H', code)) - if reason is not None: - payload += _truncate_utf8(reason.encode('utf-8'), - MAX_PAYLOAD_NORMAL - 2) - - return self._serialize_frame(Opcode.CLOSE, payload) - - def ping(self, payload=b''): - return self._serialize_frame(Opcode.PING, payload) - - def pong(self, payload=b''): - return self._serialize_frame(Opcode.PONG, payload) - - def send_data(self, payload=b'', fin=True): - if isinstance(payload, (bytes, bytearray, memoryview)): - opcode = Opcode.BINARY - elif isinstance(payload, unicode): - opcode = Opcode.TEXT - payload = payload.encode('utf-8') - else: - raise ValueError('Must provide bytes or text') - - if self._outbound_opcode is None: - self._outbound_opcode = opcode - elif self._outbound_opcode is not opcode: - raise TypeError('Data type mismatch inside message') - else: - opcode = Opcode.CONTINUATION - - if fin: - self._outbound_opcode = None - - return self._serialize_frame(opcode, payload, fin) - - def _make_fin_rsv_opcode(self, fin, rsv, opcode): - fin = int(fin) << 7 - rsv = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + \ - (int(rsv.rsv3) << 4) - opcode = int(opcode) - - return fin | rsv | opcode - - def _serialize_frame(self, opcode, payload=b'', fin=True): - rsv = RsvBits(False, False, False) - for extension in reversed(self.extensions): - rsv, payload = extension.frame_outbound(self, opcode, rsv, payload, - fin) - - fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode) - - payload_length = len(payload) - quad_payload = False - if payload_length <= MAX_PAYLOAD_NORMAL: - first_payload = payload_length - second_payload = None - elif payload_length <= MAX_PAYLOAD_TWO_BYTE: - first_payload = PAYLOAD_LENGTH_TWO_BYTE - second_payload = payload_length - else: - first_payload = PAYLOAD_LENGTH_EIGHT_BYTE - second_payload = payload_length - quad_payload = True - - if self.client: - first_payload |= 1 << 7 - - header = bytearray([fin_rsv_opcode, first_payload]) - if second_payload is not None: - if opcode.iscontrol(): - raise ValueError("payload too long for control frame") - if quad_payload: - header += bytearray(struct.pack('!Q', second_payload)) - else: - header += bytearray(struct.pack('!H', second_payload)) - - if self.client: - # "The masking key is a 32-bit value chosen at random by the - # client. When preparing a masked frame, the client MUST pick a - # fresh masking key from the set of allowed 32-bit values. The - # masking key needs to be unpredictable; thus, the masking key - # MUST be derived from a strong source of entropy, and the masking - # key for a given frame MUST NOT make it simple for a server/proxy - # to predict the masking key for a subsequent frame. The - # unpredictability of the masking key is essential to prevent - # authors of malicious applications from selecting the bytes that - # appear on the wire." - # -- https://tools.ietf.org/html/rfc6455#section-5.3 - masking_key = os.urandom(4) - masker = XorMaskerSimple(masking_key) - return header + masking_key + masker.process(payload) - - return header + payload diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index 92f99518d..2d8458a5b 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -2,10 +2,10 @@ import socket from OpenSSL import SSL -from mitmproxy.contrib import wsproto -from mitmproxy.contrib.wsproto import events -from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection -from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate +import wsproto +from wsproto import events +from wsproto.connection import ConnectionType, WSConnection +from wsproto.extensions import PerMessageDeflate from mitmproxy import exceptions from mitmproxy import flow diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index a37edb54a..662578523 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -1,7 +1,8 @@ import time from typing import List, Optional -from mitmproxy.contrib import wsproto +from wsproto.frame_protocol import CloseReason +from wsproto.frame_protocol import Opcode from mitmproxy import flow from mitmproxy.net import websockets @@ -17,7 +18,7 @@ class WebSocketMessage(serializable.Serializable): def __init__( self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None, killed: bool=False ) -> None: - self.type = wsproto.frame_protocol.Opcode(type) # type: ignore + self.type = Opcode(type) # type: ignore """indicates either TEXT or BINARY (from wsproto.frame_protocol.Opcode).""" self.from_client = from_client """True if this messages was sent by the client.""" @@ -37,10 +38,10 @@ class WebSocketMessage(serializable.Serializable): def set_state(self, state): self.type, self.from_client, self.content, self.timestamp, self.killed = state - self.type = wsproto.frame_protocol.Opcode(self.type) # replace enum with bare int + self.type = Opcode(self.type) # replace enum with bare int def __repr__(self): - if self.type == wsproto.frame_protocol.Opcode.TEXT: + if self.type == Opcode.TEXT: return "text message: {}".format(repr(self.content)) else: return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) @@ -66,7 +67,7 @@ class WebSocketFlow(flow.Flow): """A list containing all WebSocketMessage's.""" self.close_sender = 'client' """'client' if the client initiated connection closing.""" - self.close_code = wsproto.frame_protocol.CloseReason.NORMAL_CLOSURE + self.close_code = CloseReason.NORMAL_CLOSURE """WebSocket close code.""" self.close_message = '(message missing)' """WebSocket close message.""" diff --git a/setup.py b/setup.py index 4ae1974bf..c66d1382e 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,7 @@ setup( "sortedcontainers>=1.5.4, <1.6", "tornado>=4.3, <4.6", "urwid>=1.3.1, <1.4", + "wsproto>=0.11.0,<0.12.0", ], extras_require={ ':sys_platform == "win32"': [