mitmproxy/test/netlib/websockets/test_frame.py

165 lines
5.0 KiB
Python
Raw Normal View History

import os
import codecs
import pytest
from netlib import websockets
from netlib import tutils
2016-10-17 04:29:45 +00:00
class TestFrameHeader:
@pytest.mark.parametrize("input,expected", [
(0, '0100'),
(125, '017D'),
(126, '017E007E'),
(127, '017E007F'),
(142, '017E008E'),
(65534, '017EFFFE'),
(65535, '017EFFFF'),
(65536, '017F0000000000010000'),
(8589934591, '017F00000001FFFFFFFF'),
(2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
])
def test_serialization_length(self, input, expected):
h = websockets.FrameHeader(
opcode=websockets.OPCODE.TEXT,
payload_length=input,
)
assert bytes(h) == codecs.decode(expected, 'hex')
def test_serialization_too_large(self):
h = websockets.FrameHeader(
payload_length=2 ** 64 + 1,
)
with pytest.raises(ValueError):
bytes(h)
@pytest.mark.parametrize("input,expected", [
('0100', 0),
('017D', 125),
('017E007E', 126),
('017E007F', 127),
('017E008E', 142),
('017EFFFE', 65534),
('017EFFFF', 65535),
('017F0000000000010000', 65536),
('017F00000001FFFFFFFF', 8589934591),
('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
])
def test_deserialization_length(self, input, expected):
h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
assert h.payload_length == expected
@pytest.mark.parametrize("input,expected", [
('0100', (False, None)),
('018000000000', (True, '00000000')),
('018012345678', (True, '12345678')),
])
def test_deserialization_masking(self, input, expected):
h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
assert h.mask == expected[0]
if h.mask:
assert h.masking_key == codecs.decode(expected[1], 'hex')
def test_equality(self):
h = websockets.FrameHeader(mask=True, masking_key=b'1234')
h2 = websockets.FrameHeader(mask=True, masking_key=b'1234')
assert h == h2
h = websockets.FrameHeader(fin=True)
h2 = websockets.FrameHeader(fin=False)
assert h != h2
assert h != 'foobar'
def test_roundtrip(self):
def round(*args, **kwargs):
h = websockets.FrameHeader(*args, **kwargs)
h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h)))
assert h == h2
round()
round(fin=True)
round(rsv1=True)
round(rsv2=True)
round(rsv3=True)
round(payload_length=1)
round(payload_length=100)
round(payload_length=1000)
round(payload_length=10000)
round(opcode=websockets.OPCODE.PING)
round(masking_key=b"test")
def test_human_readable(self):
f = websockets.FrameHeader(
masking_key=b"test",
fin=True,
payload_length=10
)
assert repr(f)
f = websockets.FrameHeader()
assert repr(f)
def test_funky(self):
f = websockets.FrameHeader(masking_key=b"test", mask=False)
raw = bytes(f)
f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
assert not f2.mask
def test_violations(self):
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
def test_automask(self):
f = websockets.FrameHeader(mask=True)
assert f.masking_key
f = websockets.FrameHeader(masking_key=b"foob")
assert f.mask
f = websockets.FrameHeader(masking_key=b"foob", mask=0)
assert not f.mask
assert f.masking_key
2016-10-17 04:29:45 +00:00
class TestFrame:
def test_equality(self):
f = websockets.Frame(payload=b'1234')
f2 = websockets.Frame(payload=b'1234')
assert f == f2
assert f != b'1234'
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
raw = bytes(f)
f2 = websockets.Frame.from_file(tutils.treader(raw))
assert f == f2
round(b"test")
round(b"test", fin=1)
round(b"test", rsv1=1)
round(b"test", opcode=websockets.OPCODE.PING)
round(b"test", masking_key=b"test")
def test_human_readable(self):
f = websockets.Frame()
assert repr(f)
f = websockets.Frame(b"foobar")
assert "foobar" in repr(f)
@pytest.mark.parametrize("masked", [True, False])
@pytest.mark.parametrize("length", [100, 50000, 150000])
def test_serialization_bijection(self, masked, length):
frame = websockets.Frame(
os.urandom(length),
fin=True,
opcode=websockets.OPCODE.TEXT,
mask=int(masked),
masking_key=(os.urandom(4) if masked else None)
)
serialized = bytes(frame)
assert frame == websockets.Frame.from_bytes(serialized)