vendoring of wsproto

https://github.com/python-hyper/wsproto.git
commit 5ea2da61266796666f5de6461aaae22e6b00deba
This commit is contained in:
Thomas Kriechbaumer 2017-12-12 21:47:24 +01:00
parent 70e1409261
commit f5fafbfcb5
7 changed files with 1419 additions and 4 deletions

View File

@ -0,0 +1,20 @@
# flake8: noqa
import sys
PY2 = sys.version_info.major == 2
PY3 = sys.version_info.major == 3
if PY3:
unicode = str
def Utf8Validator():
return None
else:
unicode = unicode
try:
from wsaccel.utf8validator import Utf8Validator
except ImportError:
from .utf8validator import Utf8Validator

View File

@ -0,0 +1,477 @@
# -*- coding: utf-8 -*-
"""
wsproto/connection
~~~~~~~~~~~~~~
An implementation of a WebSocket connection.
"""
import os
import base64
import hashlib
from collections import deque
from enum import Enum
import h11
from .events import (
ConnectionRequested, ConnectionEstablished, ConnectionClosed,
ConnectionFailed, TextReceived, BytesReceived, PingReceived, PongReceived
)
from .frame_protocol import FrameProtocol, ParseFailed, CloseReason, Opcode
# RFC6455, Section 1.3 - Opening Handshake
ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
class ConnectionState(Enum):
"""
RFC 6455, Section 4 - Opening Handshake
"""
CONNECTING = 0
OPEN = 1
CLOSING = 2
CLOSED = 3
class ConnectionType(Enum):
CLIENT = 1
SERVER = 2
CLIENT = ConnectionType.CLIENT
SERVER = ConnectionType.SERVER
# Some convenience utilities for working with HTTP headers
def _normed_header_dict(h11_headers):
# This mangles Set-Cookie headers. But it happens that we don't care about
# any of those, so it's OK. For every other HTTP header, if there are
# multiple instances then you're allowed to join them together with
# commas.
name_to_values = {}
for name, value in h11_headers:
name_to_values.setdefault(name, []).append(value)
name_to_normed_value = {}
for name, values in name_to_values.items():
name_to_normed_value[name] = b", ".join(values)
return name_to_normed_value
# We use this for parsing the proposed protocol list, and for parsing the
# proposed and accepted extension lists. For the proposed protocol list it's
# fine, because the ABNF is just 1#token. But for the extension lists, it's
# wrong, because those can contain quoted strings, which can in turn contain
# commas. XX FIXME
def _split_comma_header(value):
return [piece.decode('ascii').strip() for piece in value.split(b',')]
class WSConnection(object):
"""
A low-level WebSocket connection object.
This wraps two other protocol objects, an HTTP/1.1 protocol object used
to do the initial HTTP upgrade handshake and a WebSocket frame protocol
object used to exchange messages and other control frames.
:param conn_type: Whether this object is on the client- or server-side of
a connection. To initialise as a client pass ``CLIENT`` otherwise
pass ``SERVER``.
:type conn_type: ``ConnectionType``
:param host: The hostname to pass to the server when acting as a client.
:type host: ``str``
:param resource: The resource (aka path) to pass to the server when acting
as a client.
:type resource: ``str``
:param extensions: A list of extensions to use on this connection.
Extensions should be instances of a subclass of
:class:`Extension <wsproto.extensions.Extension>`.
:param subprotocols: A list of subprotocols to request when acting as a
client, ordered by preference. This has no impact on the connection
itself.
:type subprotocol: ``list`` of ``str``
"""
def __init__(self, conn_type, host=None, resource=None, extensions=None,
subprotocols=None):
self.client = conn_type is ConnectionType.CLIENT
self.host = host
self.resource = resource
self.subprotocols = subprotocols or []
self.extensions = extensions or []
self.version = b'13'
self._state = ConnectionState.CONNECTING
self._close_reason = None
self._nonce = None
self._outgoing = b''
self._events = deque()
self._proto = None
if self.client:
self._upgrade_connection = h11.Connection(h11.CLIENT)
else:
self._upgrade_connection = h11.Connection(h11.SERVER)
if self.client:
if self.host is None:
raise ValueError(
"Host must not be None for a client-side connection.")
if self.resource is None:
raise ValueError(
"Resource must not be None for a client-side connection.")
self.initiate_connection()
def initiate_connection(self):
self._generate_nonce()
headers = {
b"Host": self.host.encode('ascii'),
b"Upgrade": b'WebSocket',
b"Connection": b'Upgrade',
b"Sec-WebSocket-Key": self._nonce,
b"Sec-WebSocket-Version": self.version,
}
if self.subprotocols:
headers[b"Sec-WebSocket-Protocol"] = ", ".join(self.subprotocols)
if self.extensions:
offers = {e.name: e.offer(self) for e in self.extensions}
extensions = []
for name, params in offers.items():
if params is True:
extensions.append(name.encode('ascii'))
elif params:
# py34 annoyance: doesn't support bytestring formatting
extensions.append(('%s; %s' % (name, params))
.encode("ascii"))
if extensions:
headers[b'Sec-WebSocket-Extensions'] = b', '.join(extensions)
upgrade = h11.Request(method=b'GET', target=self.resource,
headers=headers.items())
self._outgoing += self._upgrade_connection.send(upgrade)
def send_data(self, payload, final=True):
"""
Send a message or part of a message to the remote peer.
If ``final`` is ``False`` it indicates that this is part of a longer
message. If ``final`` is ``True`` it indicates that this is either a
self-contained message or the last part of a longer message.
If ``payload`` is of type ``bytes`` then the message is flagged as
being binary If it is of type ``str`` encoded as UTF-8 and sent as
text.
:param payload: The message body to send.
:type payload: ``bytes`` or ``str``
:param final: Whether there are more parts to this message to be sent.
:type final: ``bool``
"""
self._outgoing += self._proto.send_data(payload, final)
def close(self, code=CloseReason.NORMAL_CLOSURE, reason=None):
self._outgoing += self._proto.close(code, reason)
self._state = ConnectionState.CLOSING
@property
def closed(self):
return self._state is ConnectionState.CLOSED
def bytes_to_send(self, amount=None):
"""
Return any data that is to be sent to the remote peer.
:param amount: (optional) The maximum number of bytes to be provided.
If ``None`` or not provided it will return all available bytes.
:type amount: ``int``
"""
if amount is None:
data = self._outgoing
self._outgoing = b''
else:
data = self._outgoing[:amount]
self._outgoing = self._outgoing[amount:]
return data
def receive_bytes(self, data):
"""
Pass some received bytes to the connection for processing.
:param data: The data received from the remote peer.
:type data: ``bytes``
"""
if data is None and self._state is ConnectionState.OPEN:
# "If _The WebSocket Connection is Closed_ and no Close control
# frame was received by the endpoint (such as could occur if the
# underlying transport connection is lost), _The WebSocket
# Connection Close Code_ is considered to be 1006."
self._events.append(ConnectionClosed(CloseReason.ABNORMAL_CLOSURE))
self._state = ConnectionState.CLOSED
return
elif data is None:
self._state = ConnectionState.CLOSED
return
if self._state is ConnectionState.CONNECTING:
event, data = self._process_upgrade(data)
if event is not None:
self._events.append(event)
if self._state is ConnectionState.OPEN:
self._proto.receive_bytes(data)
def _process_upgrade(self, data):
self._upgrade_connection.receive_data(data)
while True:
try:
event = self._upgrade_connection.next_event()
except h11.RemoteProtocolError:
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Bad HTTP message"), b''
if event is h11.NEED_DATA:
break
elif self.client and isinstance(event, (h11.InformationalResponse,
h11.Response)):
data = self._upgrade_connection.trailing_data[0]
return self._establish_client_connection(event), data
elif not self.client and isinstance(event, h11.Request):
return self._process_connection_request(event), None
else:
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Bad HTTP message"), b''
self._incoming = b''
return None, None
def events(self):
"""
Return a generator that provides any events that have been generated
by protocol activity.
:returns: generator
"""
while self._events:
yield self._events.popleft()
if self._proto is None:
return
try:
for frame in self._proto.received_frames():
if frame.opcode is Opcode.PING:
assert frame.frame_finished and frame.message_finished
self._outgoing += self._proto.pong(frame.payload)
yield PingReceived(frame.payload)
elif frame.opcode is Opcode.PONG:
assert frame.frame_finished and frame.message_finished
yield PongReceived(frame.payload)
elif frame.opcode is Opcode.CLOSE:
code, reason = frame.payload
self.close(code, reason)
yield ConnectionClosed(code, reason)
elif frame.opcode is Opcode.TEXT:
yield TextReceived(frame.payload,
frame.frame_finished,
frame.message_finished)
elif frame.opcode is Opcode.BINARY:
yield BytesReceived(frame.payload,
frame.frame_finished,
frame.message_finished)
except ParseFailed as exc:
# XX FIXME: apparently autobahn intentionally deviates from the
# spec in that on protocol errors it just closes the connection
# rather than trying to send a CLOSE frame. Investigate whether we
# should do the same.
self.close(code=exc.code, reason=str(exc))
yield ConnectionClosed(exc.code, reason=str(exc))
def _generate_nonce(self):
# os.urandom may be overkill for this use case, but I don't think this
# is a bottleneck, and better safe than sorry...
self._nonce = base64.b64encode(os.urandom(16))
def _generate_accept_token(self, token):
accept_token = token + ACCEPT_GUID
accept_token = hashlib.sha1(accept_token).digest()
return base64.b64encode(accept_token)
def _establish_client_connection(self, event):
if event.status_code != 101:
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Bad status code from server")
headers = _normed_header_dict(event.headers)
if headers[b'connection'].lower() != b'upgrade':
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Missing Connection: Upgrade header")
if headers[b'upgrade'].lower() != b'websocket':
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Missing Upgrade: WebSocket header")
accept_token = self._generate_accept_token(self._nonce)
if headers[b'sec-websocket-accept'] != accept_token:
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Bad accept token")
subprotocol = headers.get(b'sec-websocket-protocol', None)
if subprotocol is not None:
subprotocol = subprotocol.decode('ascii')
if subprotocol not in self.subprotocols:
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"unrecognized subprotocol {!r}"
.format(subprotocol))
extensions = headers.get(b'sec-websocket-extensions', None)
if extensions:
accepts = _split_comma_header(extensions)
for accept in accepts:
name = accept.split(';', 1)[0].strip()
for extension in self.extensions:
if extension.name == name:
extension.finalize(self, accept)
break
else:
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"unrecognized extension {!r}"
.format(name))
self._proto = FrameProtocol(self.client, self.extensions)
self._state = ConnectionState.OPEN
return ConnectionEstablished(subprotocol, extensions)
def _process_connection_request(self, event):
if event.method != b'GET':
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Request method must be GET")
headers = _normed_header_dict(event.headers)
if headers[b'connection'].lower() != b'upgrade':
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Missing Connection: Upgrade header")
if headers[b'upgrade'].lower() != b'websocket':
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Missing Upgrade: WebSocket header")
if b'sec-websocket-version' not in headers:
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Missing Sec-WebSocket-Version header")
# XX FIXME: need to check Sec-Websocket-Version, and respond with a
# 400 if it's not what we expect
if b'sec-websocket-protocol' in headers:
proposed_subprotocols = _split_comma_header(
headers[b'sec-websocket-protocol'])
else:
proposed_subprotocols = []
if b'sec-websocket-key' not in headers:
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
"Missing Sec-WebSocket-Key header")
return ConnectionRequested(proposed_subprotocols, event)
def _extension_accept(self, extensions_header):
accepts = {}
offers = _split_comma_header(extensions_header)
for offer in offers:
name = offer.split(';', 1)[0].strip()
for extension in self.extensions:
if extension.name == name:
accept = extension.accept(self, offer)
if accept is True:
accepts[extension.name] = True
elif accept is not False and accept is not None:
accepts[extension.name] = accept.encode('ascii')
if accepts:
extensions = []
for name, params in accepts.items():
if params is True:
extensions.append(name.encode('ascii'))
else:
# py34 annoyance: doesn't support bytestring formatting
params = params.decode("ascii")
extensions.append(('%s; %s' % (name, params))
.encode("ascii"))
return b', '.join(extensions)
return None
def accept(self, event, subprotocol=None):
request = event.h11request
request_headers = _normed_header_dict(request.headers)
nonce = request_headers[b'sec-websocket-key']
accept_token = self._generate_accept_token(nonce)
headers = {
b"Upgrade": b'WebSocket',
b"Connection": b'Upgrade',
b"Sec-WebSocket-Accept": accept_token,
}
if subprotocol is not None:
if subprotocol not in event.proposed_subprotocols:
raise ValueError(
"unexpected subprotocol {!r}".format(subprotocol))
headers[b'Sec-WebSocket-Protocol'] = subprotocol
extensions = request_headers.get(b'sec-websocket-extensions', None)
if extensions:
accepts = self._extension_accept(extensions)
if accepts:
headers[b"Sec-WebSocket-Extensions"] = accepts
response = h11.InformationalResponse(status_code=101,
headers=headers.items())
self._outgoing += self._upgrade_connection.send(response)
self._proto = FrameProtocol(self.client, self.extensions)
self._state = ConnectionState.OPEN
def ping(self, payload=None):
"""
Send a PING message to the peer.
:param payload: an optional payload to send with the message
"""
payload = bytes(payload or b'')
self._outgoing += self._proto.ping(payload)
def pong(self, payload=None):
"""
Send a PONG message to the peer.
This method can be used to send an unsolicted PONG to the peer.
It is not needed otherwise since every received PING causes a
corresponding PONG to be sent automatically.
:param payload: an optional payload to send with the message
"""
payload = bytes(payload or b'')
self._outgoing += self._proto.pong(payload)

View File

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
"""
wsproto/events
~~~~~~~~~~
Events that result from processing data on a WebSocket connection.
"""
class ConnectionRequested(object):
def __init__(self, proposed_subprotocols, h11request):
self.proposed_subprotocols = proposed_subprotocols
self.h11request = h11request
def __repr__(self):
path = self.h11request.target
headers = dict(self.h11request.headers)
host = headers[b'host']
version = headers[b'sec-websocket-version']
subprotocol = headers.get(b'sec-websocket-protocol', None)
extensions = []
fmt = '<%s host=%s path=%s version=%s subprotocol=%r extensions=%r>'
return fmt % (self.__class__.__name__, host, path, version,
subprotocol, extensions)
class ConnectionEstablished(object):
def __init__(self, subprotocol=None, extensions=None):
self.subprotocol = subprotocol
self.extensions = extensions
if self.extensions is None:
self.extensions = []
def __repr__(self):
return '<ConnectionEstablished subprotocol=%r extensions=%r>' % \
(self.subprotocol, self.extensions)
class ConnectionClosed(object):
def __init__(self, code, reason=None):
self.code = code
self.reason = reason
def __repr__(self):
return '<%s code=%r reason="%s">' % (self.__class__.__name__,
self.code, self.reason)
class ConnectionFailed(ConnectionClosed):
pass
class DataReceived(object):
def __init__(self, data, frame_finished, message_finished):
self.data = data
# This has no semantic content, but is provided just in case some
# weird edge case user wants to be able to reconstruct the
# fragmentation pattern of the original stream. You don't want it:
self.frame_finished = frame_finished
# This is the field that you almost certainly want:
self.message_finished = message_finished
class TextReceived(DataReceived):
pass
class BytesReceived(DataReceived):
pass
class PingReceived(object):
def __init__(self, payload):
self.payload = payload
class PongReceived(object):
def __init__(self, payload):
self.payload = payload

View File

@ -0,0 +1,257 @@
# -*- coding: utf-8 -*-
"""
wsproto/extensions
~~~~~~~~~~~~~~
WebSocket extensions.
"""
import zlib
from .frame_protocol import CloseReason, Opcode, RsvBits
class Extension(object):
name = None
def enabled(self):
return False
def offer(self, connection):
pass
def accept(self, connection, offer):
pass
def finalize(self, connection, offer):
pass
def frame_inbound_header(self, proto, opcode, rsv, payload_length):
return RsvBits(False, False, False)
def frame_inbound_payload_data(self, proto, data):
return data
def frame_inbound_complete(self, proto, fin):
pass
def frame_outbound(self, proto, opcode, rsv, data, fin):
return (rsv, data)
class PerMessageDeflate(Extension):
name = 'permessage-deflate'
DEFAULT_CLIENT_MAX_WINDOW_BITS = 15
DEFAULT_SERVER_MAX_WINDOW_BITS = 15
def __init__(self, client_no_context_takeover=False,
client_max_window_bits=None, server_no_context_takeover=False,
server_max_window_bits=None):
self.client_no_context_takeover = client_no_context_takeover
if client_max_window_bits is None:
client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
self.client_max_window_bits = client_max_window_bits
self.server_no_context_takeover = server_no_context_takeover
if server_max_window_bits is None:
server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS
self.server_max_window_bits = server_max_window_bits
self._compressor = None
self._decompressor = None
# This refers to the current frame
self._inbound_is_compressible = None
# This refers to the ongoing message (which might span multiple
# frames). Only the first frame in a fragmented message is flagged for
# compression, so this carries that bit forward.
self._inbound_compressed = None
self._enabled = False
def _compressible_opcode(self, opcode):
return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION)
def enabled(self):
return self._enabled
def offer(self, connection):
parameters = [
'client_max_window_bits=%d' % self.client_max_window_bits,
'server_max_window_bits=%d' % self.server_max_window_bits,
]
if self.client_no_context_takeover:
parameters.append('client_no_context_takeover')
if self.server_no_context_takeover:
parameters.append('server_no_context_takeover')
return '; '.join(parameters)
def finalize(self, connection, offer):
bits = [b.strip() for b in offer.split(';')]
for bit in bits[1:]:
if bit.startswith('client_no_context_takeover'):
self.client_no_context_takeover = True
elif bit.startswith('server_no_context_takeover'):
self.server_no_context_takeover = True
elif bit.startswith('client_max_window_bits'):
self.client_max_window_bits = int(bit.split('=', 1)[1].strip())
elif bit.startswith('server_max_window_bits'):
self.server_max_window_bits = int(bit.split('=', 1)[1].strip())
self._enabled = True
def _parse_params(self, params):
client_max_window_bits = None
server_max_window_bits = None
bits = [b.strip() for b in params.split(';')]
for bit in bits[1:]:
if bit.startswith('client_no_context_takeover'):
self.client_no_context_takeover = True
elif bit.startswith('server_no_context_takeover'):
self.server_no_context_takeover = True
elif bit.startswith('client_max_window_bits'):
if '=' in bit:
client_max_window_bits = int(bit.split('=', 1)[1].strip())
else:
client_max_window_bits = self.client_max_window_bits
elif bit.startswith('server_max_window_bits'):
if '=' in bit:
server_max_window_bits = int(bit.split('=', 1)[1].strip())
else:
server_max_window_bits = self.server_max_window_bits
return client_max_window_bits, server_max_window_bits
def accept(self, connection, offer):
client_max_window_bits, server_max_window_bits = \
self._parse_params(offer)
self._enabled = True
parameters = []
if self.client_no_context_takeover:
parameters.append('client_no_context_takeover')
if client_max_window_bits is not None:
parameters.append('client_max_window_bits=%d' %
client_max_window_bits)
self.client_max_window_bits = client_max_window_bits
if self.server_no_context_takeover:
parameters.append('server_no_context_takeover')
if server_max_window_bits is not None:
parameters.append('server_max_window_bits=%d' %
server_max_window_bits)
self.server_max_window_bits = server_max_window_bits
return '; '.join(parameters)
def frame_inbound_header(self, proto, opcode, rsv, payload_length):
if rsv.rsv1 and opcode.iscontrol():
return CloseReason.PROTOCOL_ERROR
elif rsv.rsv1 and opcode is Opcode.CONTINUATION:
return CloseReason.PROTOCOL_ERROR
self._inbound_is_compressible = self._compressible_opcode(opcode)
if self._inbound_compressed is None:
self._inbound_compressed = rsv.rsv1
if self._inbound_compressed:
assert self._inbound_is_compressible
if proto.client:
bits = self.server_max_window_bits
else:
bits = self.client_max_window_bits
if self._decompressor is None:
self._decompressor = zlib.decompressobj(-int(bits))
return RsvBits(True, False, False)
def frame_inbound_payload_data(self, proto, data):
if not self._inbound_compressed or not self._inbound_is_compressible:
return data
try:
return self._decompressor.decompress(bytes(data))
except zlib.error:
return CloseReason.INVALID_FRAME_PAYLOAD_DATA
def frame_inbound_complete(self, proto, fin):
if not fin:
return
elif not self._inbound_is_compressible:
return
elif not self._inbound_compressed:
return
try:
data = self._decompressor.decompress(b'\x00\x00\xff\xff')
data += self._decompressor.flush()
except zlib.error:
return CloseReason.INVALID_FRAME_PAYLOAD_DATA
if proto.client:
no_context_takeover = self.server_no_context_takeover
else:
no_context_takeover = self.client_no_context_takeover
if no_context_takeover:
self._decompressor = None
self._inbound_compressed = None
return data
def frame_outbound(self, proto, opcode, rsv, data, fin):
if not self._compressible_opcode(opcode):
return (rsv, data)
if opcode is not Opcode.CONTINUATION:
rsv = RsvBits(True, *rsv[1:])
if self._compressor is None:
assert opcode is not Opcode.CONTINUATION
if proto.client:
bits = self.client_max_window_bits
else:
bits = self.server_max_window_bits
self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
zlib.DEFLATED, -int(bits))
data = self._compressor.compress(bytes(data))
if fin:
data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
data = data[:-4]
if proto.client:
no_context_takeover = self.client_no_context_takeover
else:
no_context_takeover = self.server_no_context_takeover
if no_context_takeover:
self._compressor = None
return (rsv, data)
def __repr__(self):
descr = ['client_max_window_bits=%d' % self.client_max_window_bits]
if self.client_no_context_takeover:
descr.append('client_no_context_takeover')
descr.append('server_max_window_bits=%d' % self.server_max_window_bits)
if self.server_no_context_takeover:
descr.append('server_no_context_takeover')
descr = '; '.join(descr)
return '<%s %s>' % (self.__class__.__name__, descr)
#: SUPPORTED_EXTENSIONS maps all supported extension names to their class.
#: This can be used to iterate all supported extensions of wsproto, instantiate
#: new extensions based on their name, or check if a given extension is
#: supported or not.
SUPPORTED_EXTENSIONS = {
PerMessageDeflate.name: PerMessageDeflate
}

View File

@ -0,0 +1,579 @@
# -*- coding: utf-8 -*-
"""
wsproto/frame_protocol
~~~~~~~~~~~~~~
WebSocket frame protocol implementation.
"""
import os
import itertools
import struct
from codecs import getincrementaldecoder
from collections import namedtuple
from enum import Enum, IntEnum
from .compat import unicode, Utf8Validator
try:
from wsaccel.xormask import XorMaskerSimple
except ImportError:
class XorMaskerSimple:
def __init__(self, masking_key):
self._maskbytes = itertools.cycle(bytearray(masking_key))
def process(self, data):
maskbytes = self._maskbytes
return bytearray(b ^ next(maskbytes) for b in bytearray(data))
class XorMaskerNull:
def process(self, data):
return data
# RFC6455, Section 5.2 - Base Framing Protocol
# Payload length constants
PAYLOAD_LENGTH_TWO_BYTE = 126
PAYLOAD_LENGTH_EIGHT_BYTE = 127
MAX_PAYLOAD_NORMAL = 125
MAX_PAYLOAD_TWO_BYTE = 2 ** 16 - 1
MAX_PAYLOAD_EIGHT_BYTE = 2 ** 64 - 1
MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE
# MASK and PAYLOAD LEN are packed into a byte
MASK_MASK = 0x80
PAYLOAD_LEN_MASK = 0x7f
# FIN, RSV[123] and OPCODE are packed into a single byte
FIN_MASK = 0x80
RSV1_MASK = 0x40
RSV2_MASK = 0x20
RSV3_MASK = 0x10
OPCODE_MASK = 0x0f
class Opcode(IntEnum):
"""
RFC 6455, Section 5.2 - Base Framing Protocol
"""
CONTINUATION = 0x0
TEXT = 0x1
BINARY = 0x2
CLOSE = 0x8
PING = 0x9
PONG = 0xA
def iscontrol(self):
return bool(self & 0x08)
class CloseReason(IntEnum):
"""
RFC 6455, Section 7.4.1 - Defined Status Codes
"""
NORMAL_CLOSURE = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
NO_STATUS_RCVD = 1005
ABNORMAL_CLOSURE = 1006
INVALID_FRAME_PAYLOAD_DATA = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXT = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
TLS_HANDSHAKE_FAILED = 1015
# RFC 6455, Section 7.4.1 - Defined Status Codes
LOCAL_ONLY_CLOSE_REASONS = (
CloseReason.NO_STATUS_RCVD,
CloseReason.ABNORMAL_CLOSURE,
CloseReason.TLS_HANDSHAKE_FAILED,
)
# RFC 6455, Section 7.4.2 - Status Code Ranges
MIN_CLOSE_REASON = 1000
MIN_PROTOCOL_CLOSE_REASON = 1000
MAX_PROTOCOL_CLOSE_REASON = 2999
MIN_LIBRARY_CLOSE_REASON = 3000
MAX_LIBRARY_CLOSE_REASON = 3999
MIN_PRIVATE_CLOSE_REASON = 4000
MAX_PRIVATE_CLOSE_REASON = 4999
MAX_CLOSE_REASON = 4999
NULL_MASK = struct.pack("!I", 0)
class ParseFailed(Exception):
def __init__(self, msg, code=CloseReason.PROTOCOL_ERROR):
super(ParseFailed, self).__init__(msg)
self.code = code
Header = namedtuple("Header", "fin rsv opcode payload_len masking_key".split())
Frame = namedtuple("Frame",
"opcode payload frame_finished message_finished".split())
RsvBits = namedtuple("RsvBits", "rsv1 rsv2 rsv3".split())
def _truncate_utf8(data, nbytes):
if len(data) <= nbytes:
return data
# Truncate
data = data[:nbytes]
# But we might have cut a codepoint in half, in which case we want to
# discard the partial character so the data is at least
# well-formed. This is a little inefficient since it processes the
# whole message twice when in theory we could just peek at the last
# few characters, but since this is only used for close messages (max
# length = 125 bytes) it really doesn't matter.
data = data.decode("utf-8", errors="ignore").encode("utf-8")
return data
class Buffer(object):
def __init__(self, initial_bytes=None):
self.buffer = bytearray()
self.bytes_used = 0
if initial_bytes:
self.feed(initial_bytes)
def feed(self, new_bytes):
self.buffer += new_bytes
def consume_at_most(self, nbytes):
if not nbytes:
return bytearray()
data = self.buffer[self.bytes_used:self.bytes_used + nbytes]
self.bytes_used += len(data)
return data
def consume_exactly(self, nbytes):
if len(self.buffer) - self.bytes_used < nbytes:
return None
return self.consume_at_most(nbytes)
def commit(self):
# In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic
del self.buffer[:self.bytes_used]
self.bytes_used = 0
def rollback(self):
self.bytes_used = 0
def __len__(self):
return len(self.buffer)
class MessageDecoder(object):
def __init__(self):
self.opcode = None
self.validator = None
self.decoder = None
def process_frame(self, frame):
assert not frame.opcode.iscontrol()
if self.opcode is None:
if frame.opcode is Opcode.CONTINUATION:
raise ParseFailed("unexpected CONTINUATION")
self.opcode = frame.opcode
elif frame.opcode is not Opcode.CONTINUATION:
raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode)
if frame.opcode is Opcode.TEXT:
self.validator = Utf8Validator()
self.decoder = getincrementaldecoder("utf-8")()
finished = frame.frame_finished and frame.message_finished
if self.decoder is not None:
data = self.decode_payload(frame.payload, finished)
else:
data = frame.payload
frame = Frame(self.opcode, data, frame.frame_finished, finished)
if finished:
self.opcode = None
self.decoder = None
return frame
def decode_payload(self, data, finished):
if self.validator is not None:
results = self.validator.validate(bytes(data))
if not results[0] or (finished and not results[1]):
raise ParseFailed(u'encountered invalid UTF-8 while processing'
' text message at payload octet index %d' %
results[3],
CloseReason.INVALID_FRAME_PAYLOAD_DATA)
try:
return self.decoder.decode(data, finished)
except UnicodeDecodeError as exc:
raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA)
class FrameDecoder(object):
def __init__(self, client, extensions=None):
self.client = client
self.extensions = extensions or []
self.buffer = Buffer()
self.header = None
self.effective_opcode = None
self.masker = None
self.payload_required = 0
self.payload_consumed = 0
def receive_bytes(self, data):
self.buffer.feed(data)
def process_buffer(self):
if not self.header:
if not self.parse_header():
return None
if len(self.buffer) < self.payload_required:
return None
payload_remaining = self.header.payload_len - self.payload_consumed
payload = self.buffer.consume_at_most(payload_remaining)
if not payload and self.header.payload_len > 0:
return None
self.buffer.commit()
self.payload_consumed += len(payload)
finished = self.payload_consumed == self.header.payload_len
payload = self.masker.process(payload)
for extension in self.extensions:
payload = extension.frame_inbound_payload_data(self, payload)
if isinstance(payload, CloseReason):
raise ParseFailed("error in extension", payload)
if finished:
final = bytearray()
for extension in self.extensions:
result = extension.frame_inbound_complete(self,
self.header.fin)
if isinstance(result, CloseReason):
raise ParseFailed("error in extension", result)
if result is not None:
final += result
payload += final
frame = Frame(self.effective_opcode, payload, finished,
self.header.fin)
if finished:
self.header = None
self.effective_opcode = None
self.masker = None
else:
self.effective_opcode = Opcode.CONTINUATION
return frame
def parse_header(self):
data = self.buffer.consume_exactly(2)
if data is None:
self.buffer.rollback()
return False
fin = bool(data[0] & FIN_MASK)
rsv = RsvBits(bool(data[0] & RSV1_MASK),
bool(data[0] & RSV2_MASK),
bool(data[0] & RSV3_MASK))
opcode = data[0] & OPCODE_MASK
try:
opcode = Opcode(opcode)
except ValueError:
raise ParseFailed("Invalid opcode {:#x}".format(opcode))
if opcode.iscontrol() and not fin:
raise ParseFailed("Invalid attempt to fragment control frame")
has_mask = bool(data[1] & MASK_MASK)
payload_len = data[1] & PAYLOAD_LEN_MASK
payload_len = self.parse_extended_payload_length(opcode, payload_len)
if payload_len is None:
self.buffer.rollback()
return False
self.extension_processing(opcode, rsv, payload_len)
if has_mask and self.client:
raise ParseFailed("client received unexpected masked frame")
if not has_mask and not self.client:
raise ParseFailed("server received unexpected unmasked frame")
if has_mask:
masking_key = self.buffer.consume_exactly(4)
if masking_key is None:
self.buffer.rollback()
return False
self.masker = XorMaskerSimple(masking_key)
else:
self.masker = XorMaskerNull()
self.buffer.commit()
self.header = Header(fin, rsv, opcode, payload_len, None)
self.effective_opcode = self.header.opcode
if self.header.opcode.iscontrol():
self.payload_required = payload_len
else:
self.payload_required = 0
self.payload_consumed = 0
return True
def parse_extended_payload_length(self, opcode, payload_len):
if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL:
raise ParseFailed("Control frame with payload len > 125")
if payload_len == PAYLOAD_LENGTH_TWO_BYTE:
data = self.buffer.consume_exactly(2)
if data is None:
return None
(payload_len,) = struct.unpack("!H", data)
if payload_len <= MAX_PAYLOAD_NORMAL:
raise ParseFailed(
"Payload length used 2 bytes when 1 would have sufficed")
elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE:
data = self.buffer.consume_exactly(8)
if data is None:
return None
(payload_len,) = struct.unpack("!Q", data)
if payload_len <= MAX_PAYLOAD_TWO_BYTE:
raise ParseFailed(
"Payload length used 8 bytes when 2 would have sufficed")
if payload_len >> 63:
# I'm not sure why this is illegal, but that's what the RFC
# says, so...
raise ParseFailed("8-byte payload length with non-zero MSB")
return payload_len
def extension_processing(self, opcode, rsv, payload_len):
rsv_used = [False, False, False]
for extension in self.extensions:
result = extension.frame_inbound_header(self, opcode, rsv,
payload_len)
if isinstance(result, CloseReason):
raise ParseFailed("error in extension", result)
for bit, used in enumerate(result):
if used:
rsv_used[bit] = True
for expected, found in zip(rsv_used, rsv):
if found and not expected:
raise ParseFailed("Reserved bit set unexpectedly")
class FrameProtocol(object):
class State(Enum):
HEADER = 1
PAYLOAD = 2
FRAME_COMPLETE = 3
FAILED = 4
def __init__(self, client, extensions):
self.client = client
self.extensions = [ext for ext in extensions if ext.enabled()]
# Global state
self._frame_decoder = FrameDecoder(self.client, self.extensions)
self._message_decoder = MessageDecoder()
self._parse_more = self.parse_more_gen()
self._outbound_opcode = None
def _process_close(self, frame):
data = frame.payload
if not data:
# "If this Close control frame contains no status code, _The
# WebSocket Connection Close Code_ is considered to be 1005"
data = (CloseReason.NO_STATUS_RCVD, "")
elif len(data) == 1:
raise ParseFailed("CLOSE with 1 byte payload")
else:
(code,) = struct.unpack("!H", data[:2])
if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON:
raise ParseFailed("CLOSE with invalid code")
try:
code = CloseReason(code)
except ValueError:
pass
if code in LOCAL_ONLY_CLOSE_REASONS:
raise ParseFailed(
"remote CLOSE with local-only reason")
if not isinstance(code, CloseReason) and \
code <= MAX_PROTOCOL_CLOSE_REASON:
raise ParseFailed(
"CLOSE with unknown reserved code")
validator = Utf8Validator()
if validator is not None:
results = validator.validate(bytes(data[2:]))
if not (results[0] and results[1]):
raise ParseFailed(u'encountered invalid UTF-8 while'
' processing close message at payload'
' octet index %d' %
results[3],
CloseReason.INVALID_FRAME_PAYLOAD_DATA)
try:
reason = data[2:].decode("utf-8")
except UnicodeDecodeError as exc:
raise ParseFailed(
"Error decoding CLOSE reason: " + str(exc),
CloseReason.INVALID_FRAME_PAYLOAD_DATA)
data = (code, reason)
return Frame(frame.opcode, data, frame.frame_finished,
frame.message_finished)
def parse_more_gen(self):
# Consume as much as we can from self._buffer, yielding events, and
# then yield None when we need more data. Or raise ParseFailed.
# XX FIXME this should probably be refactored so that we never see
# disabled extensions in the first place...
self.extensions = [ext for ext in self.extensions if ext.enabled()]
closed = False
while not closed:
frame = self._frame_decoder.process_buffer()
if frame is not None:
if not frame.opcode.iscontrol():
frame = self._message_decoder.process_frame(frame)
elif frame.opcode == Opcode.CLOSE:
frame = self._process_close(frame)
closed = True
yield frame
def receive_bytes(self, data):
self._frame_decoder.receive_bytes(data)
def received_frames(self):
for event in self._parse_more:
if event is None:
break
else:
yield event
def close(self, code=None, reason=None):
payload = bytearray()
if code is None and reason is not None:
raise TypeError("cannot specify a reason without a code")
if code in LOCAL_ONLY_CLOSE_REASONS:
code = CloseReason.NORMAL_CLOSURE
if code is not None:
payload += bytearray(struct.pack('!H', code))
if reason is not None:
payload += _truncate_utf8(reason.encode('utf-8'),
MAX_PAYLOAD_NORMAL - 2)
return self._serialize_frame(Opcode.CLOSE, payload)
def ping(self, payload=b''):
return self._serialize_frame(Opcode.PING, payload)
def pong(self, payload=b''):
return self._serialize_frame(Opcode.PONG, payload)
def send_data(self, payload=b'', fin=True):
if isinstance(payload, (bytes, bytearray, memoryview)):
opcode = Opcode.BINARY
elif isinstance(payload, unicode):
opcode = Opcode.TEXT
payload = payload.encode('utf-8')
else:
raise ValueError('Must provide bytes or text')
if self._outbound_opcode is None:
self._outbound_opcode = opcode
elif self._outbound_opcode is not opcode:
raise TypeError('Data type mismatch inside message')
else:
opcode = Opcode.CONTINUATION
if fin:
self._outbound_opcode = None
return self._serialize_frame(opcode, payload, fin)
def _make_fin_rsv_opcode(self, fin, rsv, opcode):
fin = int(fin) << 7
rsv = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + \
(int(rsv.rsv3) << 4)
opcode = int(opcode)
return fin | rsv | opcode
def _serialize_frame(self, opcode, payload=b'', fin=True):
rsv = RsvBits(False, False, False)
for extension in reversed(self.extensions):
rsv, payload = extension.frame_outbound(self, opcode, rsv, payload,
fin)
fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode)
payload_length = len(payload)
quad_payload = False
if payload_length <= MAX_PAYLOAD_NORMAL:
first_payload = payload_length
second_payload = None
elif payload_length <= MAX_PAYLOAD_TWO_BYTE:
first_payload = PAYLOAD_LENGTH_TWO_BYTE
second_payload = payload_length
else:
first_payload = PAYLOAD_LENGTH_EIGHT_BYTE
second_payload = payload_length
quad_payload = True
if self.client:
first_payload |= 1 << 7
header = bytearray([fin_rsv_opcode, first_payload])
if second_payload is not None:
if opcode.iscontrol():
raise ValueError("payload too long for control frame")
if quad_payload:
header += bytearray(struct.pack('!Q', second_payload))
else:
header += bytearray(struct.pack('!H', second_payload))
if self.client:
# "The masking key is a 32-bit value chosen at random by the
# client. When preparing a masked frame, the client MUST pick a
# fresh masking key from the set of allowed 32-bit values. The
# masking key needs to be unpredictable; thus, the masking key
# MUST be derived from a strong source of entropy, and the masking
# key for a given frame MUST NOT make it simple for a server/proxy
# to predict the masking key for a subsequent frame. The
# unpredictability of the masking key is essential to prevent
# authors of malicious applications from selecting the bytes that
# appear on the wire."
# -- https://tools.ietf.org/html/rfc6455#section-5.3
masking_key = os.urandom(4)
masker = XorMaskerSimple(masking_key)
return header + masking_key + masker.process(payload)
return header + payload

View File

@ -1,10 +1,10 @@
import socket
from OpenSSL import SSL
from wsproto import events
from wsproto.connection import ConnectionType, WSConnection
from wsproto.extensions import PerMessageDeflate
from wsproto.frame_protocol import Opcode
from mitmproxy.contrib.wsproto import events
from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection
from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate
from mitmproxy.contrib.wsproto.frame_protocol import Opcode
from mitmproxy import exceptions
from mitmproxy import flow

View File

@ -65,6 +65,7 @@ setup(
"certifi>=2015.11.20.1", # no semver here - this should always be on the last release!
"click>=6.2, <7",
"cryptography>=2.0,<2.2",
'h11>=0.7.0,<0.8',
"h2>=3.0, <4",
"hyperframe>=5.0, <6",
"kaitaistruct>=0.7, <0.8",