websockets: extract frame header creation into a function

This commit is contained in:
Aldo Cortesi 2015-04-24 08:47:09 +12:00
parent 42a87a1d8b
commit bdd52fead3
2 changed files with 147 additions and 120 deletions

View File

@ -35,6 +35,139 @@ class OPCODE:
PONG = 0x0a
def apply_mask(message, masking_key):
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns
This method both encodes and decodes strings with the provided mask
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
masks = [utils.bytes_to_int(byte) for byte in masking_key]
result = ""
for char in message:
result += chr(ord(char) ^ masks[len(result) % 4])
return result
def client_handshake_headers(key=None, version=VERSION):
"""
Create the headers for a valid HTTP upgrade request. If Key is not
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of ODictCaseless
"""
if not key:
key = base64.b64encode(os.urandom(16)).decode('utf-8')
return odict.ODictCaseless([
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
('Sec-WebSocket-Key', key),
('Sec-WebSocket-Version', version)
])
def server_handshake_headers(key):
"""
The server response is a valid HTTP 101 response.
"""
return odict.ODictCaseless(
[
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
('Sec-WebSocket-Accept', create_server_nonce(key))
]
)
def get_payload_length_pair(payload_bytestring):
"""
A websockets frame contains an initial length_code, and an optional
extended length code to represent the actual length if length code is
larger than 125
"""
actual_length = len(payload_bytestring)
if actual_length <= 125:
length_code = actual_length
elif actual_length >= 126 and actual_length <= 65535:
length_code = 126
else:
length_code = 127
return (length_code, actual_length)
def make_length_code(len):
"""
A websockets frame contains an initial length_code, and an optional
extended length code to represent the actual length if length code is
larger than 125
"""
if len <= 125:
return len
elif len >= 126 and len <= 65535:
return 126
else:
return 127
def check_client_handshake(headers):
if headers.get_first("upgrade", None) != "websocket":
return
return headers.get_first('sec-websocket-key')
def check_server_handshake(headers):
if headers.get_first("upgrade", None) != "websocket":
return
return headers.get_first('sec-websocket-accept')
def create_server_nonce(client_nonce):
return base64.b64encode(
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
)
def frame_header_bytes(
opcode = 0,
payload_length = 0,
fin = 0,
rsv1 = 0,
rsv2 = 0,
rsv3 = 0,
mask = 0,
masking_key = None,
length_code = None
):
first_byte = (fin << 7) | (rsv1 << 6) |\
(rsv2 << 4) | (rsv3 << 4) | opcode
if length_code is None:
length_code = make_length_code(payload_length)
second_byte = (mask << 7) | length_code
b = chr(first_byte) + chr(second_byte)
if payload_length < 126:
pass
elif payload_length < MAX_16_BIT_INT:
# '!H' pack as 16 bit unsigned short
# add 2 byte extended payload length
b += struct.pack('!H', payload_length)
elif payload_length < MAX_64_BIT_INT:
# '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length
b += struct.pack('!Q', payload_length)
if masking_key is not None:
b += masking_key
return b
class Frame(object):
"""
Represents one websockets frame.
@ -170,43 +303,16 @@ class Frame(object):
If you haven't checked is_valid_frame() then there's no guarentees
that the serialized bytes will be correct. see safe_to_bytes()
"""
# break down of the bit-math used to construct the first byte from the
# frame's integer values first shift the significant bit into the
# correct position
# 00000001 << 7 = 10000000
# ...
# then combine:
#
# 10000000 fin
# 01000000 res1
# 00100000 res2
# 00010000 res3
# 00000001 opcode
# -------- OR
# 11110001 = first_byte
first_byte = (self.fin << 7) | (self.rsv1 << 6) |\
(self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode
second_byte = (self.mask_bit << 7) | self.payload_length_code
b = chr(first_byte) + chr(second_byte)
if self.actual_payload_length < 126:
pass
elif self.actual_payload_length < MAX_16_BIT_INT:
# '!H' pack as 16 bit unsigned short
# add 2 byte extended payload length
b += struct.pack('!H', self.actual_payload_length)
elif self.actual_payload_length < MAX_64_BIT_INT:
# '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length
b += struct.pack('!Q', self.actual_payload_length)
if self.masking_key is not None:
b += self.masking_key
b = frame_header_bytes(
opcode = self.opcode,
fin = self.fin,
rsv1 = self.rsv1,
rsv2 = self.rsv2,
rsv3 = self.rsv3,
mask = self.mask_bit,
masking_key = self.masking_key,
payload_length = self.actual_payload_length
)
b += self.payload # already will be encoded if neccessary
return b
@ -283,86 +389,3 @@ class Frame(object):
self.decoded_payload == other.decoded_payload and
self.actual_payload_length == other.actual_payload_length
)
def apply_mask(message, masking_key):
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns
This method both encodes and decodes strings with the provided mask
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
masks = [utils.bytes_to_int(byte) for byte in masking_key]
result = ""
for char in message:
result += chr(ord(char) ^ masks[len(result) % 4])
return result
def client_handshake_headers(key=None, version=VERSION):
"""
Create the headers for a valid HTTP upgrade request. If Key is not
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of ODictCaseless
"""
if not key:
key = base64.b64encode(os.urandom(16)).decode('utf-8')
return odict.ODictCaseless([
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
('Sec-WebSocket-Key', key),
('Sec-WebSocket-Version', version)
])
def server_handshake_headers(key):
"""
The server response is a valid HTTP 101 response.
"""
return odict.ODictCaseless(
[
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
('Sec-WebSocket-Accept', create_server_nonce(key))
]
)
def get_payload_length_pair(payload_bytestring):
"""
A websockets frame contains an initial length_code, and an optional
extended length code to represent the actual length if length code is
larger than 125
"""
actual_length = len(payload_bytestring)
if actual_length <= 125:
length_code = actual_length
elif actual_length >= 126 and actual_length <= 65535:
length_code = 126
else:
length_code = 127
return (length_code, actual_length)
def check_client_handshake(headers):
if headers.get_first("upgrade", None) != "websocket":
return
return headers.get_first('sec-websocket-key')
def check_server_handshake(headers):
if headers.get_first("upgrade", None) != "websocket":
return
return headers.get_first('sec-websocket-accept')
def create_server_nonce(client_nonce):
return base64.b64encode(
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
)

View File

@ -3,6 +3,10 @@ import os
from nose.tools import raises
def test_frame_header_bytes():
assert websockets.frame_header_bytes()
class WebSocketsEchoHandler(tcp.BaseHandler):
def __init__(self, connection, address, server):
super(WebSocketsEchoHandler, self).__init__(