websocket: interface refactoring

- Separate out FrameHeader. We need to deal with this separately in many circumstances.
- Simpler equality scheme.
- Bits are now specified by truthiness - we don't care about the integer value.
This means lots of validation is not needed any more.
This commit is contained in:
Aldo Cortesi 2015-04-24 15:09:21 +12:00
parent 3519871f34
commit f22bc0b4c7
3 changed files with 197 additions and 163 deletions

View File

@ -49,3 +49,19 @@ def hexdump(s):
(o, x, cleanBin(part, True)) (o, x, cleanBin(part, True))
) )
return parts return parts
def setbit(byte, offset, value):
"""
Set a bit in a byte to 1 if value is truthy, 0 if not.
"""
if value:
return byte | (1 << offset)
else:
return byte & ~(1 << offset)
def getbit(byte, offset):
mask = 1 << offset
if byte & mask:
return True

View File

@ -1,5 +1,4 @@
from __future__ import absolute_import from __future__ import absolute_import
import base64 import base64
import hashlib import hashlib
import os import os
@ -83,23 +82,6 @@ def server_handshake_headers(key):
) )
def get_payload_length_pair(payload_bytestring):
"""
A websockets frame contains an initial length_code, and an optional
extended length code to represent the actual length if length code is
larger than 125
"""
actual_length = len(payload_bytestring)
if actual_length <= 125:
length_code = actual_length
elif actual_length >= 126 and actual_length <= 65535:
length_code = 126
else:
length_code = 127
return (length_code, actual_length)
def make_length_code(len): def make_length_code(len):
""" """
A websockets frame contains an initial length_code, and an optional A websockets frame contains an initial length_code, and an optional
@ -132,41 +114,114 @@ def create_server_nonce(client_nonce):
) )
def frame_header_bytes( DEFAULT = object()
opcode = 0, class FrameHeader:
def __init__(
self,
opcode = OPCODE.TEXT,
payload_length = 0, payload_length = 0,
fin = 0, fin = False,
rsv1 = 0, rsv1 = False,
rsv2 = 0, rsv2 = False,
rsv3 = 0, rsv3 = False,
mask = 0,
masking_key = None, masking_key = None,
length_code = None mask = DEFAULT,
length_code = DEFAULT
): ):
first_byte = (fin << 7) | (rsv1 << 6) |\ self.opcode = opcode
(rsv2 << 4) | (rsv3 << 4) | opcode self.payload_length = payload_length
self.fin = fin
self.rsv1 = rsv1
self.rsv2 = rsv2
self.rsv3 = rsv3
self.mask = mask
self.masking_key = masking_key
self.length_code = length_code
if length_code is None: def to_bytes(self):
length_code = make_length_code(payload_length) first_byte = utils.setbit(0, 7, self.fin)
first_byte = utils.setbit(first_byte, 6, self.rsv1)
first_byte = utils.setbit(first_byte, 5, self.rsv2)
first_byte = utils.setbit(first_byte, 4, self.rsv3)
first_byte = first_byte | self.opcode
if self.length_code is DEFAULT:
length_code = make_length_code(self.payload_length)
else:
length_code = self.length_code
if self.mask is DEFAULT:
mask = bool(self.masking_key)
else:
mask = self.mask
second_byte = (mask << 7) | length_code second_byte = (mask << 7) | length_code
b = chr(first_byte) + chr(second_byte) b = chr(first_byte) + chr(second_byte)
if payload_length < 126: if self.payload_length < 126:
pass pass
elif payload_length < MAX_16_BIT_INT: elif self.payload_length < MAX_16_BIT_INT:
# '!H' pack as 16 bit unsigned short # '!H' pack as 16 bit unsigned short
# add 2 byte extended payload length # add 2 byte extended payload length
b += struct.pack('!H', payload_length) b += struct.pack('!H', self.payload_length)
elif payload_length < MAX_64_BIT_INT: elif self.payload_length < MAX_64_BIT_INT:
# '!Q' = pack as 64 bit unsigned long long # '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length # add 8 bytes extended payload length
b += struct.pack('!Q', payload_length) b += struct.pack('!Q', self.payload_length)
if masking_key is not None: if self.masking_key is not None:
b += masking_key b += self.masking_key
return b return b
@classmethod
def from_file(klass, fp):
"""
read a websockets frame header
"""
first_byte = utils.bytes_to_int(fp.read(1))
second_byte = utils.bytes_to_int(fp.read(1))
fin = utils.getbit(first_byte, 7)
rsv1 = utils.getbit(first_byte, 6)
rsv2 = utils.getbit(first_byte, 5)
rsv3 = utils.getbit(first_byte, 4)
# grab right most 4 bits by and-ing with 00001111
opcode = first_byte & 15
# grab left most bit
mask_bit = second_byte >> 7
# grab the next 7 bits
length_code = second_byte & 127
# payload_lengthy > 125 indicates you need to read more bytes
# to get the actual payload length
if length_code <= 125:
payload_length = length_code
elif length_code == 126:
payload_length = utils.bytes_to_int(fp.read(2))
elif length_code == 127:
payload_length = utils.bytes_to_int(fp.read(8))
# masking key only present if mask bit set
if mask_bit == 1:
masking_key = fp.read(4)
else:
masking_key = None
return klass(
fin = fin,
rsv1 = rsv1,
rsv2 = rsv2,
rsv3 = rsv3,
opcode = opcode,
mask = mask_bit,
length_code = length_code,
payload_length = payload_length,
masking_key = masking_key,
)
def __eq__(self, other):
return self.to_bytes() == other.to_bytes()
class Frame(object): class Frame(object):
""" """
@ -194,27 +249,10 @@ class Frame(object):
| Payload Data continued ... | | Payload Data continued ... |
+---------------------------------------------------------------+ +---------------------------------------------------------------+
""" """
def __init__( def __init__(self, payload = "", **kwargs):
self,
fin, # decmial integer 1 or 0
opcode, # decmial integer 1 - 4
payload = "", # bytestring
masking_key = None, # 32 bit byte string
mask_bit = 0, # decimal integer 1 or 0
payload_length_code = None, # decimal integer 1 - 127
rsv1 = 0, # decimal integer 1 or 0
rsv2 = 0, # decimal integer 1 or 0
rsv3 = 0, # decimal integer 1 or 0
):
self.fin = fin
self.rsv1 = rsv1
self.rsv2 = rsv2
self.rsv3 = rsv3
self.opcode = opcode
self.mask_bit = mask_bit
self.payload_length_code = payload_length_code
self.masking_key = masking_key
self.payload = payload self.payload = payload
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
self.header = FrameHeader(**kwargs)
@classmethod @classmethod
def default(cls, message, from_client = False): def default(cls, message, from_client = False):
@ -230,10 +268,10 @@ class Frame(object):
masking_key = None masking_key = None
return cls( return cls(
message,
fin = 1, # final frame fin = 1, # final frame
opcode = OPCODE.TEXT, # text opcode = OPCODE.TEXT, # text
mask_bit = mask_bit, mask = mask_bit,
payload = message,
masking_key = masking_key, masking_key = masking_key,
) )
@ -243,30 +281,30 @@ class Frame(object):
Frame has not been corrupted. Frame has not been corrupted.
""" """
constraints = [ constraints = [
0 <= self.fin <= 1, 0 <= self.header.fin <= 1,
0 <= self.rsv1 <= 1, 0 <= self.header.rsv1 <= 1,
0 <= self.rsv2 <= 1, 0 <= self.header.rsv2 <= 1,
0 <= self.rsv3 <= 1, 0 <= self.header.rsv3 <= 1,
1 <= self.opcode <= 4, 1 <= self.header.opcode <= 4,
0 <= self.mask_bit <= 1, 0 <= self.header.mask <= 1,
#1 <= self.payload_length_code <= 127, #1 <= self.payload_length_code <= 127,
1 <= len(self.masking_key) <= 4 if self.mask_bit else True, 1 <= len(self.header.masking_key) <= 4 if self.header.mask else True,
self.masking_key is not None if self.mask_bit else True self.header.masking_key is not None if self.header.mask else True
] ]
if not all(constraints): if not all(constraints):
return False return False
return True return True
def human_readable(self): # pragma: nocover def human_readable(self):
return "\n".join([ return "\n".join([
("fin - " + str(self.fin)), ("fin - " + str(self.header.fin)),
("rsv1 - " + str(self.rsv1)), ("rsv1 - " + str(self.header.rsv1)),
("rsv2 - " + str(self.rsv2)), ("rsv2 - " + str(self.header.rsv2)),
("rsv3 - " + str(self.rsv3)), ("rsv3 - " + str(self.header.rsv3)),
("opcode - " + str(self.opcode)), ("opcode - " + str(self.header.opcode)),
("mask_bit - " + str(self.mask_bit)), ("mask - " + str(self.header.mask)),
("payload_length_code - " + str(self.payload_length_code)), ("length_code - " + str(self.header.length_code)),
("masking_key - " + repr(str(self.masking_key))), ("masking_key - " + repr(str(self.header.masking_key))),
("payload - " + repr(str(self.payload))), ("payload - " + repr(str(self.payload))),
]) ])
@ -284,18 +322,9 @@ class Frame(object):
If you haven't checked is_valid_frame() then there's no guarentees If you haven't checked is_valid_frame() then there's no guarentees
that the serialized bytes will be correct. see safe_to_bytes() that the serialized bytes will be correct. see safe_to_bytes()
""" """
b = frame_header_bytes( b = self.header.to_bytes()
opcode = self.opcode, if self.header.masking_key:
fin = self.fin, b += apply_mask(self.payload, self.header.masking_key)
rsv1 = self.rsv1,
rsv2 = self.rsv2,
rsv3 = self.rsv3,
mask = self.mask_bit,
masking_key = self.masking_key,
payload_length = len(self.payload) if self.payload else 0
)
if self.masking_key:
b += apply_mask(self.payload, self.masking_key)
else: else:
b += self.payload b += self.payload
return b return b
@ -312,66 +341,20 @@ class Frame(object):
fp is a "file like" object that could be backed by a network fp is a "file like" object that could be backed by a network
stream or a disk or an in memory stream reader stream or a disk or an in memory stream reader
""" """
first_byte = utils.bytes_to_int(fp.read(1)) header = FrameHeader.from_file(fp)
second_byte = utils.bytes_to_int(fp.read(1)) payload = fp.read(header.payload_length)
# grab the left most bit if header.mask == 1 and header.masking_key:
fin = first_byte >> 7 payload = apply_mask(payload, header.masking_key)
# grab right most 4 bits by and-ing with 00001111
opcode = first_byte & 15
# grab left most bit
mask_bit = second_byte >> 7
# grab the next 7 bits
payload_length = second_byte & 127
# payload_lengthy > 125 indicates you need to read more bytes
# to get the actual payload length
if payload_length <= 125:
actual_payload_length = payload_length
elif payload_length == 126:
actual_payload_length = utils.bytes_to_int(fp.read(2))
elif payload_length == 127:
actual_payload_length = utils.bytes_to_int(fp.read(8))
# masking key only present if mask bit set
if mask_bit == 1:
masking_key = fp.read(4)
else:
masking_key = None
payload = fp.read(actual_payload_length)
if mask_bit == 1 and masking_key:
payload = apply_mask(payload, masking_key)
return cls( return cls(
fin = fin, payload,
opcode = opcode, fin = header.fin,
mask_bit = mask_bit, opcode = header.opcode,
payload_length_code = payload_length, mask = header.mask,
payload = payload, payload_length = header.payload_length,
masking_key = masking_key, masking_key = header.masking_key,
) )
def __eq__(self, other): def __eq__(self, other):
if self.payload_length_code is None: return self.to_bytes() == other.to_bytes()
myplc = make_length_code(len(self.payload))
else:
myplc = self.payload_length_code
if other.payload_length_code is None:
otherplc = make_length_code(len(other.payload))
else:
otherplc = other.payload_length_code
return (
self.fin == other.fin and
self.rsv1 == other.rsv1 and
self.rsv2 == other.rsv2 and
self.rsv3 == other.rsv3 and
self.opcode == other.opcode and
self.mask_bit == other.mask_bit and
self.masking_key == other.masking_key and
self.payload == other.payload,
myplc == otherplc
)

View File

@ -1,10 +1,9 @@
from netlib import tcp, test, websockets, http import cStringIO
import os import os
from nose.tools import raises from nose.tools import raises
from netlib import tcp, test, websockets, http
def test_frame_header_bytes():
assert websockets.frame_header_bytes()
class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsEchoHandler(tcp.BaseHandler):
@ -119,12 +118,12 @@ class TestWebSockets(test.ServerTestBase):
assert frame.is_valid() assert frame.is_valid()
frame = f() frame = f()
frame.fin = 2 frame.header.fin = 2
assert not frame.is_valid() assert not frame.is_valid()
frame = f() frame = f()
frame.mask_bit = 1 frame.header.mask_bit = 1
frame.masking_key = "foobbarboo" frame.header.masking_key = "foobbarboo"
assert not frame.is_valid() assert not frame.is_valid()
def test_serialization_bijection(self): def test_serialization_bijection(self):
@ -181,3 +180,39 @@ class TestBadHandshake(test.ServerTestBase):
client = WebSocketsClient(("127.0.0.1", self.port)) client = WebSocketsClient(("127.0.0.1", self.port))
client.connect() client.connect()
client.send_message("hello") client.send_message("hello")
class TestFrameHeader:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.FrameHeader(*args, **kwargs)
bytes = f.to_bytes()
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
assert f == f2
round()
round(fin=1)
round(rsv1=1)
round(rsv2=1)
round(rsv3=1)
round(payload_length=1)
round(payload_length=100)
round(payload_length=1000)
round(payload_length=10000)
round(opcode=websockets.OPCODE.PING)
round(masking_key="test")
def test_funky(self):
f = websockets.FrameHeader(masking_key="test", mask=False)
bytes = f.to_bytes()
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
assert not f2.mask
class TestFrame:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
bytes = f.to_bytes()
f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes))
assert f == f2
round("test")