diff --git a/netlib/utils.py b/netlib/utils.py index 66bbdb5e6..44bed43ab 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -49,3 +49,19 @@ def hexdump(s): (o, x, cleanBin(part, True)) ) return parts + + +def setbit(byte, offset, value): + """ + Set a bit in a byte to 1 if value is truthy, 0 if not. + """ + if value: + return byte | (1 << offset) + else: + return byte & ~(1 << offset) + + +def getbit(byte, offset): + mask = 1 << offset + if byte & mask: + return True diff --git a/netlib/websockets.py b/netlib/websockets.py index 7c127563a..016e75c2b 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -1,5 +1,4 @@ from __future__ import absolute_import - import base64 import hashlib import os @@ -83,23 +82,6 @@ def server_handshake_headers(key): ) -def get_payload_length_pair(payload_bytestring): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - actual_length = len(payload_bytestring) - - if actual_length <= 125: - length_code = actual_length - elif actual_length >= 126 and actual_length <= 65535: - length_code = 126 - else: - length_code = 127 - return (length_code, actual_length) - - def make_length_code(len): """ A websockets frame contains an initial length_code, and an optional @@ -132,40 +114,113 @@ def create_server_nonce(client_nonce): ) -def frame_header_bytes( - opcode = 0, - payload_length = 0, - fin = 0, - rsv1 = 0, - rsv2 = 0, - rsv3 = 0, - mask = 0, - masking_key = None, - length_code = None -): - first_byte = (fin << 7) | (rsv1 << 6) |\ - (rsv2 << 4) | (rsv3 << 4) | opcode +DEFAULT = object() +class FrameHeader: + def __init__( + self, + opcode = OPCODE.TEXT, + payload_length = 0, + fin = False, + rsv1 = False, + rsv2 = False, + rsv3 = False, + masking_key = None, + mask = DEFAULT, + length_code = DEFAULT + ): + self.opcode = opcode + self.payload_length = payload_length + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.mask = mask + self.masking_key = masking_key + self.length_code = length_code - if length_code is None: - length_code = make_length_code(payload_length) + def to_bytes(self): + first_byte = utils.setbit(0, 7, self.fin) + first_byte = utils.setbit(first_byte, 6, self.rsv1) + first_byte = utils.setbit(first_byte, 5, self.rsv2) + first_byte = utils.setbit(first_byte, 4, self.rsv3) + first_byte = first_byte | self.opcode - second_byte = (mask << 7) | length_code + if self.length_code is DEFAULT: + length_code = make_length_code(self.payload_length) + else: + length_code = self.length_code - b = chr(first_byte) + chr(second_byte) + if self.mask is DEFAULT: + mask = bool(self.masking_key) + else: + mask = self.mask - if payload_length < 126: - pass - elif payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', payload_length) - elif payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', payload_length) - if masking_key is not None: - b += masking_key - return b + second_byte = (mask << 7) | length_code + + b = chr(first_byte) + chr(second_byte) + + if self.payload_length < 126: + pass + elif self.payload_length < MAX_16_BIT_INT: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + b += struct.pack('!H', self.payload_length) + elif self.payload_length < MAX_64_BIT_INT: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + b += struct.pack('!Q', self.payload_length) + if self.masking_key is not None: + b += self.masking_key + return b + + @classmethod + def from_file(klass, fp): + """ + read a websockets frame header + """ + first_byte = utils.bytes_to_int(fp.read(1)) + second_byte = utils.bytes_to_int(fp.read(1)) + + fin = utils.getbit(first_byte, 7) + rsv1 = utils.getbit(first_byte, 6) + rsv2 = utils.getbit(first_byte, 5) + rsv3 = utils.getbit(first_byte, 4) + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + length_code = second_byte & 127 + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if length_code <= 125: + payload_length = length_code + elif length_code == 126: + payload_length = utils.bytes_to_int(fp.read(2)) + elif length_code == 127: + payload_length = utils.bytes_to_int(fp.read(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = fp.read(4) + else: + masking_key = None + + return klass( + fin = fin, + rsv1 = rsv1, + rsv2 = rsv2, + rsv3 = rsv3, + opcode = opcode, + mask = mask_bit, + length_code = length_code, + payload_length = payload_length, + masking_key = masking_key, + ) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() class Frame(object): @@ -194,27 +249,10 @@ class Frame(object): | Payload Data continued ... | +---------------------------------------------------------------+ """ - def __init__( - self, - fin, # decmial integer 1 or 0 - opcode, # decmial integer 1 - 4 - payload = "", # bytestring - masking_key = None, # 32 bit byte string - mask_bit = 0, # decimal integer 1 or 0 - payload_length_code = None, # decimal integer 1 - 127 - rsv1 = 0, # decimal integer 1 or 0 - rsv2 = 0, # decimal integer 1 or 0 - rsv3 = 0, # decimal integer 1 or 0 - ): - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - self.opcode = opcode - self.mask_bit = mask_bit - self.payload_length_code = payload_length_code - self.masking_key = masking_key + def __init__(self, payload = "", **kwargs): self.payload = payload + kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) + self.header = FrameHeader(**kwargs) @classmethod def default(cls, message, from_client = False): @@ -230,10 +268,10 @@ class Frame(object): masking_key = None return cls( + message, fin = 1, # final frame opcode = OPCODE.TEXT, # text - mask_bit = mask_bit, - payload = message, + mask = mask_bit, masking_key = masking_key, ) @@ -243,30 +281,30 @@ class Frame(object): Frame has not been corrupted. """ constraints = [ - 0 <= self.fin <= 1, - 0 <= self.rsv1 <= 1, - 0 <= self.rsv2 <= 1, - 0 <= self.rsv3 <= 1, - 1 <= self.opcode <= 4, - 0 <= self.mask_bit <= 1, + 0 <= self.header.fin <= 1, + 0 <= self.header.rsv1 <= 1, + 0 <= self.header.rsv2 <= 1, + 0 <= self.header.rsv3 <= 1, + 1 <= self.header.opcode <= 4, + 0 <= self.header.mask <= 1, #1 <= self.payload_length_code <= 127, - 1 <= len(self.masking_key) <= 4 if self.mask_bit else True, - self.masking_key is not None if self.mask_bit else True + 1 <= len(self.header.masking_key) <= 4 if self.header.mask else True, + self.header.masking_key is not None if self.header.mask else True ] if not all(constraints): return False return True - def human_readable(self): # pragma: nocover + def human_readable(self): return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + repr(str(self.masking_key))), + ("fin - " + str(self.header.fin)), + ("rsv1 - " + str(self.header.rsv1)), + ("rsv2 - " + str(self.header.rsv2)), + ("rsv3 - " + str(self.header.rsv3)), + ("opcode - " + str(self.header.opcode)), + ("mask - " + str(self.header.mask)), + ("length_code - " + str(self.header.length_code)), + ("masking_key - " + repr(str(self.header.masking_key))), ("payload - " + repr(str(self.payload))), ]) @@ -284,18 +322,9 @@ class Frame(object): If you haven't checked is_valid_frame() then there's no guarentees that the serialized bytes will be correct. see safe_to_bytes() """ - b = frame_header_bytes( - opcode = self.opcode, - fin = self.fin, - rsv1 = self.rsv1, - rsv2 = self.rsv2, - rsv3 = self.rsv3, - mask = self.mask_bit, - masking_key = self.masking_key, - payload_length = len(self.payload) if self.payload else 0 - ) - if self.masking_key: - b += apply_mask(self.payload, self.masking_key) + b = self.header.to_bytes() + if self.header.masking_key: + b += apply_mask(self.payload, self.header.masking_key) else: b += self.payload return b @@ -312,66 +341,20 @@ class Frame(object): fp is a "file like" object that could be backed by a network stream or a disk or an in memory stream reader """ - first_byte = utils.bytes_to_int(fp.read(1)) - second_byte = utils.bytes_to_int(fp.read(1)) + header = FrameHeader.from_file(fp) + payload = fp.read(header.payload_length) - # grab the left most bit - fin = first_byte >> 7 - # grab right most 4 bits by and-ing with 00001111 - opcode = first_byte & 15 - # grab left most bit - mask_bit = second_byte >> 7 - # grab the next 7 bits - payload_length = second_byte & 127 - - # payload_lengthy > 125 indicates you need to read more bytes - # to get the actual payload length - if payload_length <= 125: - actual_payload_length = payload_length - - elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(fp.read(2)) - - elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(fp.read(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = fp.read(4) - else: - masking_key = None - - payload = fp.read(actual_payload_length) - - if mask_bit == 1 and masking_key: - payload = apply_mask(payload, masking_key) + if header.mask == 1 and header.masking_key: + payload = apply_mask(payload, header.masking_key) return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, + payload, + fin = header.fin, + opcode = header.opcode, + mask = header.mask, + payload_length = header.payload_length, + masking_key = header.masking_key, ) def __eq__(self, other): - if self.payload_length_code is None: - myplc = make_length_code(len(self.payload)) - else: - myplc = self.payload_length_code - if other.payload_length_code is None: - otherplc = make_length_code(len(other.payload)) - else: - otherplc = other.payload_length_code - return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.masking_key == other.masking_key and - self.payload == other.payload, - myplc == otherplc - ) + return self.to_bytes() == other.to_bytes() diff --git a/test/test_websockets.py b/test/test_websockets.py index bf8ec5cda..06876e0b8 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,10 +1,9 @@ -from netlib import tcp, test, websockets, http +import cStringIO import os + from nose.tools import raises - -def test_frame_header_bytes(): - assert websockets.frame_header_bytes() +from netlib import tcp, test, websockets, http class WebSocketsEchoHandler(tcp.BaseHandler): @@ -119,12 +118,12 @@ class TestWebSockets(test.ServerTestBase): assert frame.is_valid() frame = f() - frame.fin = 2 + frame.header.fin = 2 assert not frame.is_valid() frame = f() - frame.mask_bit = 1 - frame.masking_key = "foobbarboo" + frame.header.mask_bit = 1 + frame.header.masking_key = "foobbarboo" assert not frame.is_valid() def test_serialization_bijection(self): @@ -181,3 +180,39 @@ class TestBadHandshake(test.ServerTestBase): client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message("hello") + + +class TestFrameHeader: + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.FrameHeader(*args, **kwargs) + bytes = f.to_bytes() + f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) + assert f == f2 + round() + round(fin=1) + round(rsv1=1) + round(rsv2=1) + round(rsv3=1) + round(payload_length=1) + round(payload_length=100) + round(payload_length=1000) + round(payload_length=10000) + round(opcode=websockets.OPCODE.PING) + round(masking_key="test") + + def test_funky(self): + f = websockets.FrameHeader(masking_key="test", mask=False) + bytes = f.to_bytes() + f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) + assert not f2.mask + + +class TestFrame: + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.Frame(*args, **kwargs) + bytes = f.to_bytes() + f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes)) + assert f == f2 + round("test")