diff --git a/netlib/websockets.py b/netlib/websockets.py index 0cd4dba15..1e9c96cca 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -22,11 +22,17 @@ from . import utils, odict # RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" +MAX_16_BIT_INT = (1 << 16) +MAX_64_BIT_INT = (1 << 64) -class CONST(object): - MAX_16_BIT_INT = (1 << 16) - MAX_64_BIT_INT = (1 << 64) +class OPCODE: + CONTINUE = 0x00 + TEXT = 0x01 + BINARY = 0x02 + CLOSE = 0x08 + PING = 0x09 + PONG = 0x0a class Frame(object): @@ -101,7 +107,7 @@ class Frame(object): return cls( fin = 1, # final frame - opcode = 1, # text + opcode = OPCODE.TEXT, # text mask_bit = mask_bit, payload_length_code = length_code, payload = payload, @@ -115,28 +121,27 @@ class Frame(object): Validate websocket frame invariants, call at anytime to ensure the Frame has not been corrupted. """ - try: - assert 0 <= self.fin <= 1 - assert 0 <= self.rsv1 <= 1 - assert 0 <= self.rsv2 <= 1 - assert 0 <= self.rsv3 <= 1 - assert 1 <= self.opcode <= 4 - assert 0 <= self.mask_bit <= 1 - assert 1 <= self.payload_length_code <= 127 - - if self.mask_bit == 1: - assert 1 <= len(self.masking_key) <= 4 - else: - assert self.masking_key is None - - assert self.actual_payload_length == len(self.payload) - - if self.payload is not None and self.masking_key is not None: - assert apply_mask(self.payload, self.masking_key) == self.decoded_payload - - return True - except AssertionError: + constraints = [ + 0 <= self.fin <= 1, + 0 <= self.rsv1 <= 1, + 0 <= self.rsv2 <= 1, + 0 <= self.rsv3 <= 1, + 1 <= self.opcode <= 4, + 0 <= self.mask_bit <= 1, + 1 <= self.payload_length_code <= 127, + self.actual_payload_length == len(self.payload) + ] + if not all(constraints): return False + elif self.mask_bit == 1 and not 1 <= len(self.masking_key) <= 4: + return False + elif self.mask_bit == 0 and self.masking_key is not None: + return False + elif self.payload and self.masking_key: + decoded = apply_mask(self.payload, self.masking_key) + if decoded != self.decoded_payload: + return False + return True def human_readable(self): # pragma: nocover return "\n".join([ @@ -192,11 +197,11 @@ class Frame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < CONST.MAX_16_BIT_INT: + elif self.actual_payload_length < MAX_16_BIT_INT: # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length b += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < CONST.MAX_64_BIT_INT: + elif self.actual_payload_length < MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length b += struct.pack('!Q', self.actual_payload_length) @@ -212,15 +217,15 @@ class Frame(object): writer.flush() @classmethod - def from_file(cls, reader): + def from_file(cls, fp): """ read a websockets frame sent by a server or client - reader is a "file like" object that could be backed by a network + fp is a "file like" object that could be backed by a network stream or a disk or an in memory stream reader """ - first_byte = utils.bytes_to_int(reader.read(1)) - second_byte = utils.bytes_to_int(reader.read(1)) + first_byte = utils.bytes_to_int(fp.read(1)) + second_byte = utils.bytes_to_int(fp.read(1)) # grab the left most bit fin = first_byte >> 7 @@ -237,18 +242,18 @@ class Frame(object): actual_payload_length = payload_length elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(reader.read(2)) + actual_payload_length = utils.bytes_to_int(fp.read(2)) elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(reader.read(8)) + actual_payload_length = utils.bytes_to_int(fp.read(8)) # masking key only present if mask bit set if mask_bit == 1: - masking_key = reader.read(4) + masking_key = fp.read(4) else: masking_key = None - payload = reader.read(actual_payload_length) + payload = fp.read(actual_payload_length) if mask_bit == 1: decoded_payload = apply_mask(payload, masking_key) diff --git a/test/test_websockets.py b/test/test_websockets.py index 3fc67dfee..9e205e701 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -108,6 +108,33 @@ class TestWebSockets(test.ServerTestBase): server_frame = websockets.Frame.default(msg, from_client = False) assert server_frame.is_valid() + def test_is_valid(self): + def f(): + return websockets.Frame.default(self.random_bytes(10), True) + + frame = f() + assert frame.is_valid() + + frame = f() + frame.fin = 2 + assert not frame.is_valid() + + frame = f() + frame.mask_bit = 1 + frame.masking_key = "foobbarboo" + assert not frame.is_valid() + + frame = f() + frame.mask_bit = 0 + frame.masking_key = "foob" + assert not frame.is_valid() + + frame = f() + frame.masking_key = "foob" + frame.decoded_payload = "xxxx" + assert not frame.is_valid() + + def test_serialization_bijection(self): """ Ensure that various frame types can be serialized/deserialized back