websockets: A progressive masker.

This commit is contained in:
Aldo Cortesi 2015-05-01 10:09:35 +12:00
parent 4dce7ee074
commit 7d9e38ffb1
2 changed files with 34 additions and 14 deletions

View File

@ -35,21 +35,25 @@ OPCODE = utils.BiDi(
)
def apply_mask(message, masking_key):
class Masker:
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns
This method both encodes and decodes strings with the provided mask
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
masks = [utils.bytes_to_int(byte) for byte in masking_key]
result = ""
for char in message:
result += chr(ord(char) ^ masks[len(result) % 4])
return result
def __init__(self, key):
self.key = key
self.masks = [utils.bytes_to_int(byte) for byte in key]
self.offset = 0
def __call__(self, data):
result = ""
for c in data:
result += chr(ord(c) ^ self.masks[self.offset % 4])
self.offset += 1
return result
def client_handshake_headers(key=None, version=VERSION):
@ -324,7 +328,7 @@ class Frame(object):
"""
b = self.header.to_bytes()
if self.header.masking_key:
b += apply_mask(self.payload, self.header.masking_key)
b += Masker(self.header.masking_key)(self.payload)
else:
b += self.payload
return b
@ -345,7 +349,7 @@ class Frame(object):
payload = fp.read(header.payload_length)
if header.mask == 1 and header.masking_key:
payload = apply_mask(payload, header.masking_key)
payload = Masker(header.masking_key)(payload)
return cls(
payload,

View File

@ -232,3 +232,19 @@ class TestFrame:
def test_human_readable(self):
f = websockets.Frame()
assert f.human_readable()
def test_masker():
tests = [
["a"],
["four"],
["fourf"],
["fourfive"],
["a", "aasdfasdfa", "asdf"],
["a"*50, "aasdfasdfa", "asdf"],
]
for i in tests:
m = websockets.Masker("abcd")
data = "".join([m(t) for t in i])
data2 = websockets.Masker("abcd")(data)
assert data2 == "".join(i)