websockets: more whitespace, WebSocketFrame -> Frame

This commit is contained in:
Aldo Cortesi 2015-04-17 14:29:20 +12:00
parent 488c25d812
commit 7defb5be86
3 changed files with 81 additions and 76 deletions

View File

@ -9,7 +9,7 @@ import os
# Simple websocket client and servers that are used to exercise the functionality in websockets.py
# These are *not* fully RFC6455 compliant
class WebSocketsEchoHandler(tcp.BaseHandler):
class WebSocketsEchoHandler(tcp.BaseHandler):
def __init__(self, connection, address, server):
super(WebSocketsEchoHandler, self).__init__(connection, address, server)
self.handshake_done = False
@ -22,14 +22,14 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
self.read_next_message()
def read_next_message(self):
decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload
decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload
self.on_message(decoded)
def send_message(self, message):
frame = ws.WebSocketsFrame.default(message, from_client = False)
frame = ws.Frame.default(message, from_client = False)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()
def handshake(self):
client_hs = ws.read_handshake(self.rfile.read, 1)
key = ws.process_handshake_from_client(client_hs)
@ -72,9 +72,9 @@ class WebSocketsClient(tcp.TCPClient):
self.close()
def read_next_message(self):
return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload
return ws.Frame.from_byte_stream(self.rfile.read).payload
def send_message(self, message):
frame = ws.WebSocketsFrame.default(message, from_client = True)
frame = ws.Frame.default(message, from_client = True)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()

View File

@ -29,7 +29,7 @@ class WebSocketFrameValidationException(Exception):
pass
class WebSocketsFrame(object):
class Frame(object):
"""
Represents one websockets frame.
Constructor takes human readable forms of the frame components
@ -98,29 +98,29 @@ class WebSocketsFrame(object):
length_code, actual_length = get_payload_length_pair(message)
if from_client:
mask_bit = 1
mask_bit = 1
masking_key = random_masking_key()
payload = apply_mask(message, masking_key)
payload = apply_mask(message, masking_key)
else:
mask_bit = 0
mask_bit = 0
masking_key = None
payload = message
payload = message
return cls(
fin = 1, # final frame
opcode = 1, # text
mask_bit = mask_bit,
payload_length_code = length_code,
payload = payload,
masking_key = masking_key,
decoded_payload = message,
fin = 1, # final frame
opcode = 1, # text
mask_bit = mask_bit,
payload_length_code = length_code,
payload = payload,
masking_key = masking_key,
decoded_payload = message,
actual_payload_length = actual_length
)
def is_valid(self):
"""
Validate websocket frame invariants, call at anytime to ensure the
WebSocketsFrame has not been corrupted.
Validate websocket frame invariants, call at anytime to ensure the
Frame has not been corrupted.
"""
try:
assert 0 <= self.fin <= 1
@ -147,17 +147,18 @@ class WebSocketsFrame(object):
def human_readable(self):
return "\n".join([
("fin - " + str(self.fin)),
("rsv1 - " + str(self.rsv1)),
("rsv2 - " + str(self.rsv2)),
("rsv3 - " + str(self.rsv3)),
("opcode - " + str(self.opcode)),
("mask_bit - " + str(self.mask_bit)),
("payload_length_code - " + str(self.payload_length_code)),
("masking_key - " + str(self.masking_key)),
("payload - " + str(self.payload)),
("decoded_payload - " + str(self.decoded_payload)),
("actual_payload_length - " + str(self.actual_payload_length))])
("fin - " + str(self.fin)),
("rsv1 - " + str(self.rsv1)),
("rsv2 - " + str(self.rsv2)),
("rsv3 - " + str(self.rsv3)),
("opcode - " + str(self.opcode)),
("mask_bit - " + str(self.mask_bit)),
("payload_length_code - " + str(self.payload_length_code)),
("masking_key - " + str(self.masking_key)),
("payload - " + str(self.payload)),
("decoded_payload - " + str(self.decoded_payload)),
("actual_payload_length - " + str(self.actual_payload_length))
])
def safe_to_bytes(self):
if self.is_valid():
@ -167,11 +168,10 @@ class WebSocketsFrame(object):
def to_bytes(self):
"""
Serialize the frame back into the wire format, returns a bytestring If
you haven't checked is_valid_frame() then there's no guarentees that
the serialized bytes will be correct. see safe_to_bytes()
Serialize the frame back into the wire format, returns a bytestring
If you haven't checked is_valid_frame() then there's no guarentees
that the serialized bytes will be correct. see safe_to_bytes()
"""
max_16_bit_int = (1 << 16)
max_64_bit_int = (1 << 63)
@ -199,13 +199,10 @@ class WebSocketsFrame(object):
if self.actual_payload_length < 126:
pass
elif self.actual_payload_length < max_16_bit_int:
# '!H' pack as 16 bit unsigned short
# add 2 byte extended payload length
bytes += struct.pack('!H', self.actual_payload_length)
elif self.actual_payload_length < max_64_bit_int:
# '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length
@ -215,7 +212,6 @@ class WebSocketsFrame(object):
bytes += self.masking_key
bytes += self.payload # already will be encoded if neccessary
return bytes
@classmethod
@ -264,29 +260,31 @@ class WebSocketsFrame(object):
decoded_payload = payload
return cls(
fin = fin,
opcode = opcode,
mask_bit = mask_bit,
payload_length_code = payload_length,
payload = payload,
masking_key = masking_key,
decoded_payload = decoded_payload,
fin = fin,
opcode = opcode,
mask_bit = mask_bit,
payload_length_code = payload_length,
payload = payload,
masking_key = masking_key,
decoded_payload = decoded_payload,
actual_payload_length = actual_payload_length
)
def __eq__(self, other):
return (
self.fin == other.fin and
self.rsv1 == other.rsv1 and
self.rsv2 == other.rsv2 and
self.rsv3 == other.rsv3 and
self.opcode == other.opcode and
self.mask_bit == other.mask_bit and
self.payload_length_code == other.payload_length_code and
self.masking_key == other.masking_key and
self.payload == other.payload and
self.decoded_payload == other.decoded_payload and
self.actual_payload_length == other.actual_payload_length)
self.fin == other.fin and
self.rsv1 == other.rsv1 and
self.rsv2 == other.rsv2 and
self.rsv3 == other.rsv3 and
self.opcode == other.opcode and
self.mask_bit == other.mask_bit and
self.payload_length_code == other.payload_length_code and
self.masking_key == other.masking_key and
self.payload == other.payload and
self.decoded_payload == other.decoded_payload and
self.actual_payload_length == other.actual_payload_length
)
def apply_mask(message, masking_key):
"""

View File

@ -5,6 +5,7 @@ from netlib.websockets import websockets as ws
import os
from nose.tools import raises
class TestWebSockets(test.ServerTestBase):
handler = impl.WebSocketsEchoHandler
@ -22,9 +23,12 @@ class TestWebSockets(test.ServerTestBase):
self.echo("hello I'm the client")
def test_frame_sizes(self):
small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length
medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int
large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int
# length can fit in the the 7 bit payload length
small_msg = self.random_bytes(100)
# 50kb, sligthly larger than can fit in a 7 bit int
medium_msg = self.random_bytes(50000)
# 150kb, slightly larger than can fit in a 16 bit int
large_msg = self.random_bytes(150000)
self.echo(small_msg)
self.echo(medium_msg)
@ -33,51 +37,54 @@ class TestWebSockets(test.ServerTestBase):
def test_default_builder(self):
"""
default builder should always generate valid frames
"""
"""
msg = self.random_bytes()
client_frame = ws.WebSocketsFrame.default(msg, from_client = True)
client_frame = ws.Frame.default(msg, from_client = True)
assert client_frame.is_valid()
server_frame = ws.WebSocketsFrame.default(msg, from_client = False)
server_frame = ws.Frame.default(msg, from_client = False)
assert server_frame.is_valid()
def test_serialization_bijection(self):
"""
Ensure that various frame types can be serialized/deserialized back and forth
between to_bytes() and from_bytes()
"""
Ensure that various frame types can be serialized/deserialized back
and forth between to_bytes() and from_bytes()
"""
for is_client in [True, False]:
for num_bytes in [100, 50000, 150000]:
frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client)
assert frame == ws.WebSocketsFrame.from_bytes(frame.to_bytes())
for num_bytes in [100, 50000, 150000]:
frame = ws.Frame.default(
self.random_bytes(num_bytes), is_client
)
assert frame == ws.Frame.from_bytes(frame.to_bytes())
bytes = b'\x81\x11cba'
assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes
assert ws.Frame.from_bytes(bytes).to_bytes() == bytes
@raises(ws.WebSocketFrameValidationException)
def test_safe_to_bytes(self):
frame = ws.WebSocketsFrame.default(self.random_bytes(8))
frame.actual_payload_length = 1 #corrupt the frame
frame = ws.Frame.default(self.random_bytes(8))
frame.actual_payload_length = 1 # corrupt the frame
frame.safe_to_bytes()
class BadHandshakeHandler(impl.WebSocketsEchoHandler):
def handshake(self):
client_hs = ws.read_handshake(self.rfile.read, 1)
key = ws.process_handshake_from_client(client_hs)
response = ws.create_server_handshake("malformed_key")
ws.process_handshake_from_client(client_hs)
response = ws.create_server_handshake("malformed_key")
self.wfile.write(response)
self.wfile.flush()
self.handshake_done = True
class TestBadHandshake(test.ServerTestBase):
"""
Ensure that the client disconnects if the server handshake is malformed
"""
"""
handler = BadHandshakeHandler
@raises(tcp.NetLibDisconnect)
def test(self):
client = impl.WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message("hello")
client.send_message("hello")