diff --git a/netlib/websockets.py b/netlib/websockets.py index 016e75c2b..b1afa6206 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -115,6 +115,8 @@ def create_server_nonce(client_nonce): DEFAULT = object() + + class FrameHeader: def __init__( self, @@ -128,6 +130,8 @@ class FrameHeader: mask = DEFAULT, length_code = DEFAULT ): + if not 0 <= opcode < 2 ** 4: + raise ValueError("opcode must be 0-16") self.opcode = opcode self.payload_length = payload_length self.fin = fin @@ -275,26 +279,6 @@ class Frame(object): masking_key = masking_key, ) - def is_valid(self): - """ - Validate websocket frame invariants, call at anytime to ensure the - Frame has not been corrupted. - """ - constraints = [ - 0 <= self.header.fin <= 1, - 0 <= self.header.rsv1 <= 1, - 0 <= self.header.rsv2 <= 1, - 0 <= self.header.rsv3 <= 1, - 1 <= self.header.opcode <= 4, - 0 <= self.header.mask <= 1, - #1 <= self.payload_length_code <= 127, - 1 <= len(self.header.masking_key) <= 4 if self.header.mask else True, - self.header.masking_key is not None if self.header.mask else True - ] - if not all(constraints): - return False - return True - def human_readable(self): return "\n".join([ ("fin - " + str(self.header.fin)), diff --git a/test/test_websockets.py b/test/test_websockets.py index 06876e0b8..215b39581 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -4,6 +4,7 @@ import os from nose.tools import raises from netlib import tcp, test, websockets, http +import tutils class WebSocketsEchoHandler(tcp.BaseHandler): @@ -106,25 +107,7 @@ class TestWebSockets(test.ServerTestBase): """ msg = self.random_bytes() client_frame = websockets.Frame.default(msg, from_client = True) - server_frame = websockets.Frame.default(msg, from_client = False) - assert server_frame.is_valid() - - def test_is_valid(self): - def f(): - return websockets.Frame.default(self.random_bytes(10), True) - - frame = f() - assert frame.is_valid() - - frame = f() - frame.header.fin = 2 - assert not frame.is_valid() - - frame = f() - frame.header.mask_bit = 1 - frame.header.masking_key = "foobbarboo" - assert not frame.is_valid() def test_serialization_bijection(self): """ @@ -207,6 +190,9 @@ class TestFrameHeader: f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) assert not f2.mask + def test_violations(self): + tutils.raises("opcode", websockets.FrameHeader, opcode=17) + class TestFrame: def test_roundtrip(self): @@ -216,3 +202,7 @@ class TestFrame: f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes)) assert f == f2 round("test") + + def test_human_readable(self): + f = websockets.Frame() + assert f.human_readable()