websockets: A progressive masker.

This commit is contained in:
Aldo Cortesi 2015-05-01 10:09:35 +12:00
parent 4dce7ee074
commit 7d9e38ffb1
2 changed files with 34 additions and 14 deletions

View File

@ -35,20 +35,24 @@ OPCODE = utils.BiDi(
) )
def apply_mask(message, masking_key): class Masker:
""" """
Data sent from the server must be masked to prevent malicious clients Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns from sending data over the wire in predictable patterns
This method both encodes and decodes strings with the provided mask
Servers do not have to mask data they send to the client. Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3 https://tools.ietf.org/html/rfc6455#section-5.3
""" """
masks = [utils.bytes_to_int(byte) for byte in masking_key] def __init__(self, key):
self.key = key
self.masks = [utils.bytes_to_int(byte) for byte in key]
self.offset = 0
def __call__(self, data):
result = "" result = ""
for char in message: for c in data:
result += chr(ord(char) ^ masks[len(result) % 4]) result += chr(ord(c) ^ self.masks[self.offset % 4])
self.offset += 1
return result return result
@ -324,7 +328,7 @@ class Frame(object):
""" """
b = self.header.to_bytes() b = self.header.to_bytes()
if self.header.masking_key: if self.header.masking_key:
b += apply_mask(self.payload, self.header.masking_key) b += Masker(self.header.masking_key)(self.payload)
else: else:
b += self.payload b += self.payload
return b return b
@ -345,7 +349,7 @@ class Frame(object):
payload = fp.read(header.payload_length) payload = fp.read(header.payload_length)
if header.mask == 1 and header.masking_key: if header.mask == 1 and header.masking_key:
payload = apply_mask(payload, header.masking_key) payload = Masker(header.masking_key)(payload)
return cls( return cls(
payload, payload,

View File

@ -232,3 +232,19 @@ class TestFrame:
def test_human_readable(self): def test_human_readable(self):
f = websockets.Frame() f = websockets.Frame()
assert f.human_readable() assert f.human_readable()
def test_masker():
tests = [
["a"],
["four"],
["fourf"],
["fourfive"],
["a", "aasdfasdfa", "asdf"],
["a"*50, "aasdfasdfa", "asdf"],
]
for i in tests:
m = websockets.Masker("abcd")
data = "".join([m(t) for t in i])
data2 = websockets.Masker("abcd")(data)
assert data2 == "".join(i)