websockets: remove validation

We don't really need this any more. The interface is much less error prone
because bit flags are no longer integers, we have a range check on opcode on
header instantiation, and we've deferred length code calculation and so forth
into the byte render methods.
This commit is contained in:
Aldo Cortesi 2015-04-24 15:23:00 +12:00
parent f22bc0b4c7
commit def93ea8ca
2 changed files with 12 additions and 38 deletions

View File

@ -115,6 +115,8 @@ def create_server_nonce(client_nonce):
DEFAULT = object() DEFAULT = object()
class FrameHeader: class FrameHeader:
def __init__( def __init__(
self, self,
@ -128,6 +130,8 @@ class FrameHeader:
mask = DEFAULT, mask = DEFAULT,
length_code = DEFAULT length_code = DEFAULT
): ):
if not 0 <= opcode < 2 ** 4:
raise ValueError("opcode must be 0-16")
self.opcode = opcode self.opcode = opcode
self.payload_length = payload_length self.payload_length = payload_length
self.fin = fin self.fin = fin
@ -275,26 +279,6 @@ class Frame(object):
masking_key = masking_key, masking_key = masking_key,
) )
def is_valid(self):
"""
Validate websocket frame invariants, call at anytime to ensure the
Frame has not been corrupted.
"""
constraints = [
0 <= self.header.fin <= 1,
0 <= self.header.rsv1 <= 1,
0 <= self.header.rsv2 <= 1,
0 <= self.header.rsv3 <= 1,
1 <= self.header.opcode <= 4,
0 <= self.header.mask <= 1,
#1 <= self.payload_length_code <= 127,
1 <= len(self.header.masking_key) <= 4 if self.header.mask else True,
self.header.masking_key is not None if self.header.mask else True
]
if not all(constraints):
return False
return True
def human_readable(self): def human_readable(self):
return "\n".join([ return "\n".join([
("fin - " + str(self.header.fin)), ("fin - " + str(self.header.fin)),

View File

@ -4,6 +4,7 @@ import os
from nose.tools import raises from nose.tools import raises
from netlib import tcp, test, websockets, http from netlib import tcp, test, websockets, http
import tutils
class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsEchoHandler(tcp.BaseHandler):
@ -106,25 +107,7 @@ class TestWebSockets(test.ServerTestBase):
""" """
msg = self.random_bytes() msg = self.random_bytes()
client_frame = websockets.Frame.default(msg, from_client = True) client_frame = websockets.Frame.default(msg, from_client = True)
server_frame = websockets.Frame.default(msg, from_client = False) server_frame = websockets.Frame.default(msg, from_client = False)
assert server_frame.is_valid()
def test_is_valid(self):
def f():
return websockets.Frame.default(self.random_bytes(10), True)
frame = f()
assert frame.is_valid()
frame = f()
frame.header.fin = 2
assert not frame.is_valid()
frame = f()
frame.header.mask_bit = 1
frame.header.masking_key = "foobbarboo"
assert not frame.is_valid()
def test_serialization_bijection(self): def test_serialization_bijection(self):
""" """
@ -207,6 +190,9 @@ class TestFrameHeader:
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
assert not f2.mask assert not f2.mask
def test_violations(self):
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
class TestFrame: class TestFrame:
def test_roundtrip(self): def test_roundtrip(self):
@ -216,3 +202,7 @@ class TestFrame:
f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes)) f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes))
assert f == f2 assert f == f2
round("test") round("test")
def test_human_readable(self):
f = websockets.Frame()
assert f.human_readable()