mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-31 07:18:58 +00:00
websockets: extract frame header creation into a function
This commit is contained in:
parent
42a87a1d8b
commit
bdd52fead3
@ -35,6 +35,139 @@ class OPCODE:
|
||||
PONG = 0x0a
|
||||
|
||||
|
||||
def apply_mask(message, masking_key):
|
||||
"""
|
||||
Data sent from the server must be masked to prevent malicious clients
|
||||
from sending data over the wire in predictable patterns
|
||||
|
||||
This method both encodes and decodes strings with the provided mask
|
||||
|
||||
Servers do not have to mask data they send to the client.
|
||||
https://tools.ietf.org/html/rfc6455#section-5.3
|
||||
"""
|
||||
masks = [utils.bytes_to_int(byte) for byte in masking_key]
|
||||
result = ""
|
||||
for char in message:
|
||||
result += chr(ord(char) ^ masks[len(result) % 4])
|
||||
return result
|
||||
|
||||
|
||||
def client_handshake_headers(key=None, version=VERSION):
|
||||
"""
|
||||
Create the headers for a valid HTTP upgrade request. If Key is not
|
||||
specified, it is generated, and can be found in sec-websocket-key in
|
||||
the returned header set.
|
||||
|
||||
Returns an instance of ODictCaseless
|
||||
"""
|
||||
if not key:
|
||||
key = base64.b64encode(os.urandom(16)).decode('utf-8')
|
||||
return odict.ODictCaseless([
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
('Sec-WebSocket-Key', key),
|
||||
('Sec-WebSocket-Version', version)
|
||||
])
|
||||
|
||||
|
||||
def server_handshake_headers(key):
|
||||
"""
|
||||
The server response is a valid HTTP 101 response.
|
||||
"""
|
||||
return odict.ODictCaseless(
|
||||
[
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
('Sec-WebSocket-Accept', create_server_nonce(key))
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_payload_length_pair(payload_bytestring):
|
||||
"""
|
||||
A websockets frame contains an initial length_code, and an optional
|
||||
extended length code to represent the actual length if length code is
|
||||
larger than 125
|
||||
"""
|
||||
actual_length = len(payload_bytestring)
|
||||
|
||||
if actual_length <= 125:
|
||||
length_code = actual_length
|
||||
elif actual_length >= 126 and actual_length <= 65535:
|
||||
length_code = 126
|
||||
else:
|
||||
length_code = 127
|
||||
return (length_code, actual_length)
|
||||
|
||||
|
||||
def make_length_code(len):
|
||||
"""
|
||||
A websockets frame contains an initial length_code, and an optional
|
||||
extended length code to represent the actual length if length code is
|
||||
larger than 125
|
||||
"""
|
||||
if len <= 125:
|
||||
return len
|
||||
elif len >= 126 and len <= 65535:
|
||||
return 126
|
||||
else:
|
||||
return 127
|
||||
|
||||
|
||||
def check_client_handshake(headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return headers.get_first('sec-websocket-key')
|
||||
|
||||
|
||||
def check_server_handshake(headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return headers.get_first('sec-websocket-accept')
|
||||
|
||||
|
||||
def create_server_nonce(client_nonce):
|
||||
return base64.b64encode(
|
||||
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
|
||||
)
|
||||
|
||||
|
||||
def frame_header_bytes(
|
||||
opcode = 0,
|
||||
payload_length = 0,
|
||||
fin = 0,
|
||||
rsv1 = 0,
|
||||
rsv2 = 0,
|
||||
rsv3 = 0,
|
||||
mask = 0,
|
||||
masking_key = None,
|
||||
length_code = None
|
||||
):
|
||||
first_byte = (fin << 7) | (rsv1 << 6) |\
|
||||
(rsv2 << 4) | (rsv3 << 4) | opcode
|
||||
|
||||
if length_code is None:
|
||||
length_code = make_length_code(payload_length)
|
||||
|
||||
second_byte = (mask << 7) | length_code
|
||||
|
||||
b = chr(first_byte) + chr(second_byte)
|
||||
|
||||
if payload_length < 126:
|
||||
pass
|
||||
elif payload_length < MAX_16_BIT_INT:
|
||||
# '!H' pack as 16 bit unsigned short
|
||||
# add 2 byte extended payload length
|
||||
b += struct.pack('!H', payload_length)
|
||||
elif payload_length < MAX_64_BIT_INT:
|
||||
# '!Q' = pack as 64 bit unsigned long long
|
||||
# add 8 bytes extended payload length
|
||||
b += struct.pack('!Q', payload_length)
|
||||
if masking_key is not None:
|
||||
b += masking_key
|
||||
return b
|
||||
|
||||
|
||||
class Frame(object):
|
||||
"""
|
||||
Represents one websockets frame.
|
||||
@ -170,43 +303,16 @@ class Frame(object):
|
||||
If you haven't checked is_valid_frame() then there's no guarentees
|
||||
that the serialized bytes will be correct. see safe_to_bytes()
|
||||
"""
|
||||
|
||||
# break down of the bit-math used to construct the first byte from the
|
||||
# frame's integer values first shift the significant bit into the
|
||||
# correct position
|
||||
# 00000001 << 7 = 10000000
|
||||
# ...
|
||||
# then combine:
|
||||
#
|
||||
# 10000000 fin
|
||||
# 01000000 res1
|
||||
# 00100000 res2
|
||||
# 00010000 res3
|
||||
# 00000001 opcode
|
||||
# -------- OR
|
||||
# 11110001 = first_byte
|
||||
|
||||
first_byte = (self.fin << 7) | (self.rsv1 << 6) |\
|
||||
(self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode
|
||||
|
||||
second_byte = (self.mask_bit << 7) | self.payload_length_code
|
||||
|
||||
b = chr(first_byte) + chr(second_byte)
|
||||
|
||||
if self.actual_payload_length < 126:
|
||||
pass
|
||||
elif self.actual_payload_length < MAX_16_BIT_INT:
|
||||
# '!H' pack as 16 bit unsigned short
|
||||
# add 2 byte extended payload length
|
||||
b += struct.pack('!H', self.actual_payload_length)
|
||||
elif self.actual_payload_length < MAX_64_BIT_INT:
|
||||
# '!Q' = pack as 64 bit unsigned long long
|
||||
# add 8 bytes extended payload length
|
||||
b += struct.pack('!Q', self.actual_payload_length)
|
||||
|
||||
if self.masking_key is not None:
|
||||
b += self.masking_key
|
||||
|
||||
b = frame_header_bytes(
|
||||
opcode = self.opcode,
|
||||
fin = self.fin,
|
||||
rsv1 = self.rsv1,
|
||||
rsv2 = self.rsv2,
|
||||
rsv3 = self.rsv3,
|
||||
mask = self.mask_bit,
|
||||
masking_key = self.masking_key,
|
||||
payload_length = self.actual_payload_length
|
||||
)
|
||||
b += self.payload # already will be encoded if neccessary
|
||||
return b
|
||||
|
||||
@ -283,86 +389,3 @@ class Frame(object):
|
||||
self.decoded_payload == other.decoded_payload and
|
||||
self.actual_payload_length == other.actual_payload_length
|
||||
)
|
||||
|
||||
|
||||
def apply_mask(message, masking_key):
|
||||
"""
|
||||
Data sent from the server must be masked to prevent malicious clients
|
||||
from sending data over the wire in predictable patterns
|
||||
|
||||
This method both encodes and decodes strings with the provided mask
|
||||
|
||||
Servers do not have to mask data they send to the client.
|
||||
https://tools.ietf.org/html/rfc6455#section-5.3
|
||||
"""
|
||||
masks = [utils.bytes_to_int(byte) for byte in masking_key]
|
||||
result = ""
|
||||
for char in message:
|
||||
result += chr(ord(char) ^ masks[len(result) % 4])
|
||||
return result
|
||||
|
||||
|
||||
def client_handshake_headers(key=None, version=VERSION):
|
||||
"""
|
||||
Create the headers for a valid HTTP upgrade request. If Key is not
|
||||
specified, it is generated, and can be found in sec-websocket-key in
|
||||
the returned header set.
|
||||
|
||||
Returns an instance of ODictCaseless
|
||||
"""
|
||||
if not key:
|
||||
key = base64.b64encode(os.urandom(16)).decode('utf-8')
|
||||
return odict.ODictCaseless([
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
('Sec-WebSocket-Key', key),
|
||||
('Sec-WebSocket-Version', version)
|
||||
])
|
||||
|
||||
|
||||
def server_handshake_headers(key):
|
||||
"""
|
||||
The server response is a valid HTTP 101 response.
|
||||
"""
|
||||
return odict.ODictCaseless(
|
||||
[
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
('Sec-WebSocket-Accept', create_server_nonce(key))
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_payload_length_pair(payload_bytestring):
|
||||
"""
|
||||
A websockets frame contains an initial length_code, and an optional
|
||||
extended length code to represent the actual length if length code is
|
||||
larger than 125
|
||||
"""
|
||||
actual_length = len(payload_bytestring)
|
||||
|
||||
if actual_length <= 125:
|
||||
length_code = actual_length
|
||||
elif actual_length >= 126 and actual_length <= 65535:
|
||||
length_code = 126
|
||||
else:
|
||||
length_code = 127
|
||||
return (length_code, actual_length)
|
||||
|
||||
|
||||
def check_client_handshake(headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return headers.get_first('sec-websocket-key')
|
||||
|
||||
|
||||
def check_server_handshake(headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return headers.get_first('sec-websocket-accept')
|
||||
|
||||
|
||||
def create_server_nonce(client_nonce):
|
||||
return base64.b64encode(
|
||||
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
|
||||
)
|
||||
|
@ -3,6 +3,10 @@ import os
|
||||
from nose.tools import raises
|
||||
|
||||
|
||||
def test_frame_header_bytes():
|
||||
assert websockets.frame_header_bytes()
|
||||
|
||||
|
||||
class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
def __init__(self, connection, address, server):
|
||||
super(WebSocketsEchoHandler, self).__init__(
|
||||
|
Loading…
Reference in New Issue
Block a user