websocket: interface refactoring

- Separate out FrameHeader. We need to deal with this separately in many circumstances.
- Simpler equality scheme.
- Bits are now specified by truthiness - we don't care about the integer value.
This means lots of validation is not needed any more.
This commit is contained in:
Aldo Cortesi 2015-04-24 15:09:21 +12:00
parent 3519871f34
commit f22bc0b4c7
3 changed files with 197 additions and 163 deletions

View File

@ -49,3 +49,19 @@ def hexdump(s):
(o, x, cleanBin(part, True))
)
return parts
def setbit(byte, offset, value):
"""
Set a bit in a byte to 1 if value is truthy, 0 if not.
"""
if value:
return byte | (1 << offset)
else:
return byte & ~(1 << offset)
def getbit(byte, offset):
mask = 1 << offset
if byte & mask:
return True

View File

@ -1,5 +1,4 @@
from __future__ import absolute_import
import base64
import hashlib
import os
@ -83,23 +82,6 @@ def server_handshake_headers(key):
)
def get_payload_length_pair(payload_bytestring):
"""
A websockets frame contains an initial length_code, and an optional
extended length code to represent the actual length if length code is
larger than 125
"""
actual_length = len(payload_bytestring)
if actual_length <= 125:
length_code = actual_length
elif actual_length >= 126 and actual_length <= 65535:
length_code = 126
else:
length_code = 127
return (length_code, actual_length)
def make_length_code(len):
"""
A websockets frame contains an initial length_code, and an optional
@ -132,41 +114,114 @@ def create_server_nonce(client_nonce):
)
def frame_header_bytes(
opcode = 0,
DEFAULT = object()
class FrameHeader:
def __init__(
self,
opcode = OPCODE.TEXT,
payload_length = 0,
fin = 0,
rsv1 = 0,
rsv2 = 0,
rsv3 = 0,
mask = 0,
fin = False,
rsv1 = False,
rsv2 = False,
rsv3 = False,
masking_key = None,
length_code = None
mask = DEFAULT,
length_code = DEFAULT
):
first_byte = (fin << 7) | (rsv1 << 6) |\
(rsv2 << 4) | (rsv3 << 4) | opcode
self.opcode = opcode
self.payload_length = payload_length
self.fin = fin
self.rsv1 = rsv1
self.rsv2 = rsv2
self.rsv3 = rsv3
self.mask = mask
self.masking_key = masking_key
self.length_code = length_code
if length_code is None:
length_code = make_length_code(payload_length)
def to_bytes(self):
first_byte = utils.setbit(0, 7, self.fin)
first_byte = utils.setbit(first_byte, 6, self.rsv1)
first_byte = utils.setbit(first_byte, 5, self.rsv2)
first_byte = utils.setbit(first_byte, 4, self.rsv3)
first_byte = first_byte | self.opcode
if self.length_code is DEFAULT:
length_code = make_length_code(self.payload_length)
else:
length_code = self.length_code
if self.mask is DEFAULT:
mask = bool(self.masking_key)
else:
mask = self.mask
second_byte = (mask << 7) | length_code
b = chr(first_byte) + chr(second_byte)
if payload_length < 126:
if self.payload_length < 126:
pass
elif payload_length < MAX_16_BIT_INT:
elif self.payload_length < MAX_16_BIT_INT:
# '!H' pack as 16 bit unsigned short
# add 2 byte extended payload length
b += struct.pack('!H', payload_length)
elif payload_length < MAX_64_BIT_INT:
b += struct.pack('!H', self.payload_length)
elif self.payload_length < MAX_64_BIT_INT:
# '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length
b += struct.pack('!Q', payload_length)
if masking_key is not None:
b += masking_key
b += struct.pack('!Q', self.payload_length)
if self.masking_key is not None:
b += self.masking_key
return b
@classmethod
def from_file(klass, fp):
"""
read a websockets frame header
"""
first_byte = utils.bytes_to_int(fp.read(1))
second_byte = utils.bytes_to_int(fp.read(1))
fin = utils.getbit(first_byte, 7)
rsv1 = utils.getbit(first_byte, 6)
rsv2 = utils.getbit(first_byte, 5)
rsv3 = utils.getbit(first_byte, 4)
# grab right most 4 bits by and-ing with 00001111
opcode = first_byte & 15
# grab left most bit
mask_bit = second_byte >> 7
# grab the next 7 bits
length_code = second_byte & 127
# payload_lengthy > 125 indicates you need to read more bytes
# to get the actual payload length
if length_code <= 125:
payload_length = length_code
elif length_code == 126:
payload_length = utils.bytes_to_int(fp.read(2))
elif length_code == 127:
payload_length = utils.bytes_to_int(fp.read(8))
# masking key only present if mask bit set
if mask_bit == 1:
masking_key = fp.read(4)
else:
masking_key = None
return klass(
fin = fin,
rsv1 = rsv1,
rsv2 = rsv2,
rsv3 = rsv3,
opcode = opcode,
mask = mask_bit,
length_code = length_code,
payload_length = payload_length,
masking_key = masking_key,
)
def __eq__(self, other):
return self.to_bytes() == other.to_bytes()
class Frame(object):
"""
@ -194,27 +249,10 @@ class Frame(object):
| Payload Data continued ... |
+---------------------------------------------------------------+
"""
def __init__(
self,
fin, # decmial integer 1 or 0
opcode, # decmial integer 1 - 4
payload = "", # bytestring
masking_key = None, # 32 bit byte string
mask_bit = 0, # decimal integer 1 or 0
payload_length_code = None, # decimal integer 1 - 127
rsv1 = 0, # decimal integer 1 or 0
rsv2 = 0, # decimal integer 1 or 0
rsv3 = 0, # decimal integer 1 or 0
):
self.fin = fin
self.rsv1 = rsv1
self.rsv2 = rsv2
self.rsv3 = rsv3
self.opcode = opcode
self.mask_bit = mask_bit
self.payload_length_code = payload_length_code
self.masking_key = masking_key
def __init__(self, payload = "", **kwargs):
self.payload = payload
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
self.header = FrameHeader(**kwargs)
@classmethod
def default(cls, message, from_client = False):
@ -230,10 +268,10 @@ class Frame(object):
masking_key = None
return cls(
message,
fin = 1, # final frame
opcode = OPCODE.TEXT, # text
mask_bit = mask_bit,
payload = message,
mask = mask_bit,
masking_key = masking_key,
)
@ -243,30 +281,30 @@ class Frame(object):
Frame has not been corrupted.
"""
constraints = [
0 <= self.fin <= 1,
0 <= self.rsv1 <= 1,
0 <= self.rsv2 <= 1,
0 <= self.rsv3 <= 1,
1 <= self.opcode <= 4,
0 <= self.mask_bit <= 1,
0 <= self.header.fin <= 1,
0 <= self.header.rsv1 <= 1,
0 <= self.header.rsv2 <= 1,
0 <= self.header.rsv3 <= 1,
1 <= self.header.opcode <= 4,
0 <= self.header.mask <= 1,
#1 <= self.payload_length_code <= 127,
1 <= len(self.masking_key) <= 4 if self.mask_bit else True,
self.masking_key is not None if self.mask_bit else True
1 <= len(self.header.masking_key) <= 4 if self.header.mask else True,
self.header.masking_key is not None if self.header.mask else True
]
if not all(constraints):
return False
return True
def human_readable(self): # pragma: nocover
def human_readable(self):
return "\n".join([
("fin - " + str(self.fin)),
("rsv1 - " + str(self.rsv1)),
("rsv2 - " + str(self.rsv2)),
("rsv3 - " + str(self.rsv3)),
("opcode - " + str(self.opcode)),
("mask_bit - " + str(self.mask_bit)),
("payload_length_code - " + str(self.payload_length_code)),
("masking_key - " + repr(str(self.masking_key))),
("fin - " + str(self.header.fin)),
("rsv1 - " + str(self.header.rsv1)),
("rsv2 - " + str(self.header.rsv2)),
("rsv3 - " + str(self.header.rsv3)),
("opcode - " + str(self.header.opcode)),
("mask - " + str(self.header.mask)),
("length_code - " + str(self.header.length_code)),
("masking_key - " + repr(str(self.header.masking_key))),
("payload - " + repr(str(self.payload))),
])
@ -284,18 +322,9 @@ class Frame(object):
If you haven't checked is_valid_frame() then there's no guarentees
that the serialized bytes will be correct. see safe_to_bytes()
"""
b = frame_header_bytes(
opcode = self.opcode,
fin = self.fin,
rsv1 = self.rsv1,
rsv2 = self.rsv2,
rsv3 = self.rsv3,
mask = self.mask_bit,
masking_key = self.masking_key,
payload_length = len(self.payload) if self.payload else 0
)
if self.masking_key:
b += apply_mask(self.payload, self.masking_key)
b = self.header.to_bytes()
if self.header.masking_key:
b += apply_mask(self.payload, self.header.masking_key)
else:
b += self.payload
return b
@ -312,66 +341,20 @@ class Frame(object):
fp is a "file like" object that could be backed by a network
stream or a disk or an in memory stream reader
"""
first_byte = utils.bytes_to_int(fp.read(1))
second_byte = utils.bytes_to_int(fp.read(1))
header = FrameHeader.from_file(fp)
payload = fp.read(header.payload_length)
# grab the left most bit
fin = first_byte >> 7
# grab right most 4 bits by and-ing with 00001111
opcode = first_byte & 15
# grab left most bit
mask_bit = second_byte >> 7
# grab the next 7 bits
payload_length = second_byte & 127
# payload_lengthy > 125 indicates you need to read more bytes
# to get the actual payload length
if payload_length <= 125:
actual_payload_length = payload_length
elif payload_length == 126:
actual_payload_length = utils.bytes_to_int(fp.read(2))
elif payload_length == 127:
actual_payload_length = utils.bytes_to_int(fp.read(8))
# masking key only present if mask bit set
if mask_bit == 1:
masking_key = fp.read(4)
else:
masking_key = None
payload = fp.read(actual_payload_length)
if mask_bit == 1 and masking_key:
payload = apply_mask(payload, masking_key)
if header.mask == 1 and header.masking_key:
payload = apply_mask(payload, header.masking_key)
return cls(
fin = fin,
opcode = opcode,
mask_bit = mask_bit,
payload_length_code = payload_length,
payload = payload,
masking_key = masking_key,
payload,
fin = header.fin,
opcode = header.opcode,
mask = header.mask,
payload_length = header.payload_length,
masking_key = header.masking_key,
)
def __eq__(self, other):
if self.payload_length_code is None:
myplc = make_length_code(len(self.payload))
else:
myplc = self.payload_length_code
if other.payload_length_code is None:
otherplc = make_length_code(len(other.payload))
else:
otherplc = other.payload_length_code
return (
self.fin == other.fin and
self.rsv1 == other.rsv1 and
self.rsv2 == other.rsv2 and
self.rsv3 == other.rsv3 and
self.opcode == other.opcode and
self.mask_bit == other.mask_bit and
self.masking_key == other.masking_key and
self.payload == other.payload,
myplc == otherplc
)
return self.to_bytes() == other.to_bytes()

View File

@ -1,10 +1,9 @@
from netlib import tcp, test, websockets, http
import cStringIO
import os
from nose.tools import raises
def test_frame_header_bytes():
assert websockets.frame_header_bytes()
from netlib import tcp, test, websockets, http
class WebSocketsEchoHandler(tcp.BaseHandler):
@ -119,12 +118,12 @@ class TestWebSockets(test.ServerTestBase):
assert frame.is_valid()
frame = f()
frame.fin = 2
frame.header.fin = 2
assert not frame.is_valid()
frame = f()
frame.mask_bit = 1
frame.masking_key = "foobbarboo"
frame.header.mask_bit = 1
frame.header.masking_key = "foobbarboo"
assert not frame.is_valid()
def test_serialization_bijection(self):
@ -181,3 +180,39 @@ class TestBadHandshake(test.ServerTestBase):
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message("hello")
class TestFrameHeader:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.FrameHeader(*args, **kwargs)
bytes = f.to_bytes()
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
assert f == f2
round()
round(fin=1)
round(rsv1=1)
round(rsv2=1)
round(rsv3=1)
round(payload_length=1)
round(payload_length=100)
round(payload_length=1000)
round(payload_length=10000)
round(opcode=websockets.OPCODE.PING)
round(masking_key="test")
def test_funky(self):
f = websockets.FrameHeader(masking_key="test", mask=False)
bytes = f.to_bytes()
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
assert not f2.mask
class TestFrame:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
bytes = f.to_bytes()
f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes))
assert f == f2
round("test")