mitmproxy/netlib/websockets/frame.py
Maximilian Hils 73586b1be9 python 3++
2015-09-21 00:44:17 +02:00

314 lines
9.9 KiB
Python

from __future__ import absolute_import
import os
import struct
import io
import warnings
import six
from .protocol import Masker
from netlib import tcp
from netlib import utils
# Exclusive upper bounds for payload lengths: a length below the first bound
# fits the 2 byte extended length field, a length below the second bound fits
# the 8 byte extended length field.
MAX_16_BIT_INT = 2 ** 16
MAX_64_BIT_INT = 2 ** 64

# Websockets frame opcodes (RFC 6455, section 5.2), wrapped in utils.BiDi so
# that a numeric opcode can be turned back into its name via get_name().
OPCODE = utils.BiDi(
    CONTINUE=0x0,
    TEXT=0x1,
    BINARY=0x2,
    CLOSE=0x8,
    PING=0x9,
    PONG=0xA,
)
class FrameHeader(object):
    """
        Header of a single websockets frame (RFC 6455, section 5.2).

        Holds the flag bits, the opcode, the payload length (plus the 7 bit
        length code announcing how the length is encoded) and the optional
        masking key. Serialize with bytes(header), parse with from_file().
    """

    def __init__(
        self,
        opcode=OPCODE.TEXT,
        payload_length=0,
        fin=False,
        rsv1=False,
        rsv2=False,
        rsv3=False,
        masking_key=None,
        mask=None,
        length_code=None
    ):
        """
            opcode: 4 bit frame opcode, see OPCODE.
            payload_length: length of the payload that follows the header.
            fin, rsv1, rsv2, rsv3: flag bits of the first header byte.
            masking_key: 4 byte masking key.
            mask: value of the mask bit.
            length_code: 7 bit length code; derived from payload_length if
                not given. An explicit value is stored as-is, even if it
                does not match payload_length.

            mask and masking_key are resolved as follows:
              - neither given: unmasked, empty key
              - only masking_key given: mask bit is set
              - only mask given: a random 4 byte key is generated
                (regardless of the value of mask)
              - both given: both are stored unchanged

            Raises ValueError if opcode is not in the range 0-15, or if a
            non-empty masking key is not exactly 4 bytes long.
        """
        if not 0 <= opcode < 2 ** 4:
            raise ValueError("opcode must be 0-15")
        self.opcode = opcode
        self.payload_length = payload_length
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3

        if length_code is None:
            self.length_code = self._make_length_code(self.payload_length)
        else:
            self.length_code = length_code

        if mask is None and masking_key is None:
            self.mask = False
            self.masking_key = b""
        elif mask is None:
            self.mask = 1
            self.masking_key = masking_key
        elif masking_key is None:
            self.mask = mask
            self.masking_key = os.urandom(4)
        else:
            self.mask = mask
            self.masking_key = masking_key

        # Falsy keys (b"", False) mean "no key" and are not length-checked.
        if self.masking_key and len(self.masking_key) != 4:
            raise ValueError("Masking key must be 4 bytes.")

    @classmethod
    def _make_length_code(cls, length):
        """
            A websockets frame contains an initial length_code, and an
            optional extended length code to represent the actual length if
            length code is larger than 125.

            Returns length itself if it is at most 125, 126 if it is at most
            65535 (2 byte extended length), and 127 otherwise (8 byte
            extended length).
        """
        if length <= 125:
            return length
        elif length <= 65535:
            return 126
        else:
            return 127

    def __repr__(self):
        """
            Human readable one-line summary: opcode name, set flags, masking
            key and payload size. Empty parts are omitted.
        """
        vals = [
            "ws frame:",
            # unknown opcodes fall back to their hex representation
            OPCODE.get_name(self.opcode, hex(self.opcode)).lower()
        ]
        flags = [
            name
            for name in ("fin", "rsv1", "rsv2", "rsv3", "mask")
            if getattr(self, name)
        ]
        if flags:
            vals.extend([":", "|".join(flags)])
        if self.masking_key:
            vals.append(":key=%s" % repr(self.masking_key))
        if self.payload_length:
            vals.append(" %s" % utils.pretty_size(self.payload_length))
        return "".join(vals)

    def human_readable(self):
        """
            Deprecated alias for repr(frame_header).
        """
        warnings.warn(
            "FrameHeader.human_readable is deprecated, "
            "use repr(frame_header) instead.",
            DeprecationWarning
        )
        return repr(self)

    def __bytes__(self):
        """
            Serialize the header to its wire format.

            Layout: flags/opcode byte, mask bit/length code byte, optional
            2 or 8 byte extended payload length, optional masking key.
            The masking key is appended whenever it is non-empty, independent
            of the mask bit.

            Raises ValueError if payload_length does not fit into the 8 byte
            extended length field.
        """
        first_byte = utils.setbit(0, 7, self.fin)
        first_byte = utils.setbit(first_byte, 6, self.rsv1)
        first_byte = utils.setbit(first_byte, 5, self.rsv2)
        first_byte = utils.setbit(first_byte, 4, self.rsv3)
        first_byte = first_byte | self.opcode

        second_byte = utils.setbit(self.length_code, 7, self.mask)

        b = six.int2byte(first_byte) + six.int2byte(second_byte)

        if self.payload_length < 126:
            # length is fully described by the length code
            pass
        elif self.payload_length < MAX_16_BIT_INT:
            # '!H' = pack as 16 bit unsigned short, network byte order
            # add 2 byte extended payload length
            b += struct.pack('!H', self.payload_length)
        elif self.payload_length < MAX_64_BIT_INT:
            # '!Q' = pack as 64 bit unsigned long long, network byte order
            # add 8 bytes extended payload length
            b += struct.pack('!Q', self.payload_length)
        else:
            # Without this, a header with a length code but no extended
            # length field would be emitted silently.
            raise ValueError("Payload length exceeds 64bit integer")

        if self.masking_key:
            b += self.masking_key
        return b

    # On Python 2, str(header) has to yield the wire format as well.
    if six.PY2:
        __str__ = __bytes__

    def to_bytes(self):
        """
            Deprecated alias for bytes(frame_header).
        """
        warnings.warn(
            "FrameHeader.to_bytes is deprecated, "
            "use bytes(frame_header) instead.",
            DeprecationWarning
        )
        return bytes(self)

    @classmethod
    def from_file(cls, fp):
        """
            read a websockets frame header

            fp is a file-like object providing safe_read(n). Exactly the
            bytes that make up the header are consumed, the payload is left
            unread.
        """
        first_byte = six.byte2int(fp.safe_read(1))
        second_byte = six.byte2int(fp.safe_read(1))

        fin = utils.getbit(first_byte, 7)
        rsv1 = utils.getbit(first_byte, 6)
        rsv2 = utils.getbit(first_byte, 5)
        rsv3 = utils.getbit(first_byte, 4)
        # grab right-most 4 bits
        opcode = first_byte & 15
        mask_bit = utils.getbit(second_byte, 7)
        # grab the remaining 7 bits
        length_code = second_byte & 127

        # length_code > 125 indicates you need to read more bytes
        # to get the actual payload length
        if length_code <= 125:
            payload_length = length_code
        elif length_code == 126:
            payload_length, = struct.unpack("!H", fp.safe_read(2))
        else:
            # length_code is a 7 bit value, so this is the 127 case
            payload_length, = struct.unpack("!Q", fp.safe_read(8))

        # masking key only present if mask bit set
        if mask_bit == 1:
            masking_key = fp.safe_read(4)
        else:
            # False rather than None: passing None together with an explicit
            # mask would make __init__ generate a random key.
            masking_key = False

        return cls(
            fin=fin,
            rsv1=rsv1,
            rsv2=rsv2,
            rsv3=rsv3,
            opcode=opcode,
            mask=mask_bit,
            length_code=length_code,
            payload_length=payload_length,
            masking_key=masking_key,
        )

    def __eq__(self, other):
        """
            Two headers are equal if they serialize to the same bytes.
        """
        if isinstance(other, FrameHeader):
            return bytes(self) == bytes(other)
        return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self == other
class Frame(object):
    """
        Represents one websockets frame.
        Constructor takes human readable forms of the frame components
        from_bytes() is also available.

        WebSockets Frame as defined in RFC6455

          0                   1                   2                   3
          0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
         +-+-+-+-+-------+-+-------------+-------------------------------+
         |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
         |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
         |N|V|V|V|       |S|             |   (if payload len==126/127)   |
         | |1|2|3|       |K|             |                               |
         +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
         |     Extended payload length continued, if payload len == 127  |
         + - - - - - - - - - - - - - - - +-------------------------------+
         |                               |Masking-key, if MASK set to 1  |
         +-------------------------------+-------------------------------+
         | Masking-key (continued)       |          Payload Data         |
         +-------------------------------- - - - - - - - - - - - - - - - +
         :                     Payload Data continued ...                :
         + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
         |                     Payload Data continued ...                |
         +---------------------------------------------------------------+
    """

    def __init__(self, payload=b"", **kwargs):
        """
            payload: the unmasked payload bytes.
            kwargs: passed on to FrameHeader. payload_length defaults to
                len(payload) if it is not given explicitly.
        """
        self.payload = payload
        kwargs.setdefault("payload_length", len(payload))
        self.header = FrameHeader(**kwargs)

    @classmethod
    def default(cls, message, from_client=False):
        """
            Construct a basic websocket frame from some default values.
            Creates a non-fragmented text frame.

            If from_client is set, the mask bit is set and a random masking
            key is generated; otherwise the frame is unmasked.
        """
        if from_client:
            mask_bit = 1
            masking_key = os.urandom(4)
        else:
            mask_bit = 0
            # False rather than None: FrameHeader generates a random key if
            # mask is given and masking_key is None.
            masking_key = False

        return cls(
            message,
            fin=1,  # final frame
            opcode=OPCODE.TEXT,  # text
            mask=mask_bit,
            masking_key=masking_key,
        )

    @classmethod
    def from_bytes(cls, bytestring):
        """
            Construct a websocket frame from an in-memory bytestring
            to construct a frame from a stream of bytes, use from_file()
            directly
        """
        return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))

    def __repr__(self):
        """
            The header summary, followed by the cleaned payload if there is
            one.
        """
        ret = repr(self.header)
        if self.payload:
            ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii")
        return ret

    def human_readable(self):
        """
            Deprecated alias for repr(frame).
        """
        warnings.warn(
            "Frame.human_readable is deprecated, use repr(frame) instead.",
            DeprecationWarning
        )
        return repr(self)

    def __bytes__(self):
        """
            Serialize the frame to wire format. Returns bytes.

            The payload is run through Masker whenever the header carries a
            non-empty masking key.
        """
        b = bytes(self.header)
        if self.header.masking_key:
            b += Masker(self.header.masking_key)(self.payload)
        else:
            b += self.payload
        return b

    # On Python 2, str(frame) has to yield the wire format as well.
    if six.PY2:
        __str__ = __bytes__

    def to_bytes(self):
        """
            Deprecated alias for bytes(frame).
        """
        warnings.warn(
            "Frame.to_bytes is deprecated, use bytes(frame) instead.",
            DeprecationWarning
        )
        return bytes(self)

    def to_file(self, writer):
        """
            Deprecated: write the serialized frame to writer and flush it.
        """
        warnings.warn(
            "Frame.to_file is deprecated, "
            "use wfile.write(bytes(frame)) instead.",
            DeprecationWarning
        )
        writer.write(bytes(self))
        writer.flush()

    @classmethod
    def from_file(cls, fp):
        """
            read a websockets frame sent by a server or client

            fp is a "file like" object that could be backed by a network
            stream or a disk or an in memory stream reader
        """
        header = FrameHeader.from_file(fp)
        payload = fp.safe_read(header.payload_length)

        # Frame stores the payload unmasked, so undo the masking of the
        # bytes read from the wire.
        if header.mask == 1 and header.masking_key:
            payload = Masker(header.masking_key)(payload)

        return cls(
            payload,
            fin=header.fin,
            opcode=header.opcode,
            mask=header.mask,
            payload_length=header.payload_length,
            masking_key=header.masking_key,
            rsv1=header.rsv1,
            rsv2=header.rsv2,
            rsv3=header.rsv3,
            length_code=header.length_code
        )

    def __eq__(self, other):
        """
            Two frames are equal if they serialize to the same bytes.
        """
        if isinstance(other, Frame):
            return bytes(self) == bytes(other)
        return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self == other