websockets: refactor to avoid rundantly specifying payloads and payload lengths

This commit is contained in:
Aldo Cortesi 2015-04-24 09:21:04 +12:00
parent bdd52fead3
commit 3519871f34
2 changed files with 26 additions and 51 deletions

View File

@ -198,15 +198,13 @@ class Frame(object):
self,
fin, # decmial integer 1 or 0
opcode, # decmial integer 1 - 4
mask_bit, # decimal integer 1 or 0
payload_length_code, # decimal integer 1 - 127
decoded_payload, # bytestring
payload = "", # bytestring
masking_key = None, # 32 bit byte string
mask_bit = 0, # decimal integer 1 or 0
payload_length_code = None, # decimal integer 1 - 127
rsv1 = 0, # decimal integer 1 or 0
rsv2 = 0, # decimal integer 1 or 0
rsv3 = 0, # decimal integer 1 or 0
payload = None, # bytestring
masking_key = None, # 32 bit byte string
actual_payload_length = None, # any decimal integer
):
self.fin = fin
self.rsv1 = rsv1
@ -217,8 +215,6 @@ class Frame(object):
self.payload_length_code = payload_length_code
self.masking_key = masking_key
self.payload = payload
self.decoded_payload = decoded_payload
self.actual_payload_length = actual_payload_length
@classmethod
def default(cls, message, from_client = False):
@ -226,27 +222,19 @@ class Frame(object):
Construct a basic websocket frame from some default values.
Creates a non-fragmented text frame.
"""
length_code, actual_length = get_payload_length_pair(message)
if from_client:
mask_bit = 1
# Random masking key
masking_key = os.urandom(4)
payload = apply_mask(message, masking_key)
else:
mask_bit = 0
masking_key = None
payload = message
return cls(
fin = 1, # final frame
opcode = OPCODE.TEXT, # text
mask_bit = mask_bit,
payload_length_code = length_code,
payload = payload,
payload = message,
masking_key = masking_key,
decoded_payload = message,
actual_payload_length = actual_length
)
def is_valid(self):
@ -261,17 +249,12 @@ class Frame(object):
0 <= self.rsv3 <= 1,
1 <= self.opcode <= 4,
0 <= self.mask_bit <= 1,
1 <= self.payload_length_code <= 127,
self.actual_payload_length == len(self.payload),
#1 <= self.payload_length_code <= 127,
1 <= len(self.masking_key) <= 4 if self.mask_bit else True,
self.masking_key is not None if self.mask_bit else True
]
if not all(constraints):
return False
elif self.payload and self.masking_key:
decoded = apply_mask(self.payload, self.masking_key)
if decoded != self.decoded_payload:
return False
return True
def human_readable(self): # pragma: nocover
@ -285,8 +268,6 @@ class Frame(object):
("payload_length_code - " + str(self.payload_length_code)),
("masking_key - " + repr(str(self.masking_key))),
("payload - " + repr(str(self.payload))),
("decoded_payload - " + repr(str(self.decoded_payload))),
("actual_payload_length - " + str(self.actual_payload_length))
])
@classmethod
@ -311,9 +292,12 @@ class Frame(object):
rsv3 = self.rsv3,
mask = self.mask_bit,
masking_key = self.masking_key,
payload_length = self.actual_payload_length
payload_length = len(self.payload) if self.payload else 0
)
b += self.payload # already will be encoded if neccessary
if self.masking_key:
b += apply_mask(self.payload, self.masking_key)
else:
b += self.payload
return b
def to_file(self, writer):
@ -359,10 +343,8 @@ class Frame(object):
payload = fp.read(actual_payload_length)
if mask_bit == 1:
decoded_payload = apply_mask(payload, masking_key)
else:
decoded_payload = payload
if mask_bit == 1 and masking_key:
payload = apply_mask(payload, masking_key)
return cls(
fin = fin,
@ -371,11 +353,17 @@ class Frame(object):
payload_length_code = payload_length,
payload = payload,
masking_key = masking_key,
decoded_payload = decoded_payload,
actual_payload_length = actual_payload_length
)
def __eq__(self, other):
if self.payload_length_code is None:
myplc = make_length_code(len(self.payload))
else:
myplc = self.payload_length_code
if other.payload_length_code is None:
otherplc = make_length_code(len(other.payload))
else:
otherplc = other.payload_length_code
return (
self.fin == other.fin and
self.rsv1 == other.rsv1 and
@ -383,9 +371,7 @@ class Frame(object):
self.rsv3 == other.rsv3 and
self.opcode == other.opcode and
self.mask_bit == other.mask_bit and
self.payload_length_code == other.payload_length_code and
self.masking_key == other.masking_key and
self.payload == other.payload and
self.decoded_payload == other.decoded_payload and
self.actual_payload_length == other.actual_payload_length
self.payload == other.payload,
myplc == otherplc
)

View File

@ -23,7 +23,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
def read_next_message(self):
frame = websockets.Frame.from_file(self.rfile)
self.on_message(frame.decoded_payload)
self.on_message(frame.payload)
def send_message(self, message):
frame = websockets.Frame.default(message, from_client = False)
@ -107,7 +107,6 @@ class TestWebSockets(test.ServerTestBase):
"""
msg = self.random_bytes()
client_frame = websockets.Frame.default(msg, from_client = True)
assert client_frame.is_valid()
server_frame = websockets.Frame.default(msg, from_client = False)
assert server_frame.is_valid()
@ -128,17 +127,6 @@ class TestWebSockets(test.ServerTestBase):
frame.masking_key = "foobbarboo"
assert not frame.is_valid()
frame = f()
frame.mask_bit = 0
frame.masking_key = "foob"
assert not frame.is_valid()
frame = f()
frame.masking_key = "foob"
frame.decoded_payload = "xxxx"
assert not frame.is_valid()
def test_serialization_bijection(self):
"""
Ensure that various frame types can be serialized/deserialized back
@ -149,9 +137,10 @@ class TestWebSockets(test.ServerTestBase):
frame = websockets.Frame.default(
self.random_bytes(num_bytes), is_client
)
assert frame == websockets.Frame.from_bytes(
frame2 = websockets.Frame.from_bytes(
frame.to_bytes()
)
assert frame == frame2
bytes = b'\x81\x03cba'
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes