# mitmproxy/test/pathod/language/test_websockets_frame.py

import os
import codecs
import pytest
from wsproto.frame_protocol import Opcode
from pathod.language import websockets_frame
from mitmproxy.test import tutils
class TestMasker:
    """Tests for ``websockets_frame.Masker`` using the fixed key ``b"abcd"``."""

    @pytest.mark.parametrize("chunks,expected", [
        ([b"a"], '00'),
        ([b"four"], '070d1616'),
        ([b"fourf"], '070d161607'),
        ([b"fourfive"], '070d1616070b1501'),
        ([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'),
        ([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'),  # noqa
    ])
    def test_masker(self, chunks, expected):
        """Mask ``chunks`` through one Masker, then unmask with a fresh one.

        Each expected hex string is the input bytes XORed with the key
        ``b"abcd"`` repeated every 4 bytes. The multi-chunk cases show that a
        single Masker instance keeps its key position across successive calls
        instead of restarting at the first key byte.

        ``chunks`` was previously named ``input``, which shadowed the builtin.
        """
        m = websockets_frame.Masker(b"abcd")
        # Feed the chunks one call at a time so key-offset carry-over is exercised.
        data = b"".join([m(chunk) for chunk in chunks])
        assert data == codecs.decode(expected, 'hex')
        # Masking is its own inverse: a fresh masker with the same key, applied
        # to the whole masked buffer in one call, restores the original bytes.
        data = websockets_frame.Masker(b"abcd")(data)
        assert data == b"".join(chunks)


class TestFrameHeader:
    """Tests for ``websockets_frame.FrameHeader``.

    Covers payload-length encoding and parsing, mask/key parsing, equality,
    serialize/parse round-trips, constructor validation and the defaults that
    link ``mask`` and ``masking_key``.
    """

    @pytest.mark.parametrize("length,expected", [
        (0, '0100'),
        (125, '017D'),
        (126, '017E007E'),
        (127, '017E007F'),
        (142, '017E008E'),
        (65534, '017EFFFE'),
        (65535, '017EFFFF'),
        (65536, '017F0000000000010000'),
        (8589934591, '017F00000001FFFFFFFF'),
        (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
    ])
    def test_serialization_length(self, length, expected):
        """A TEXT header serializes its payload length as the expected hex.

        As the cases above show:

        - lengths up to 125 sit directly in the second byte;
        - 126..65535 use a 0x7E marker followed by 2 length bytes;
        - larger lengths use a 0x7F marker followed by 8 length bytes.
        """
        h = websockets_frame.FrameHeader(
            opcode=Opcode.TEXT,
            payload_length=length,
        )
        assert bytes(h) == codecs.decode(expected, 'hex')

    def test_serialization_too_large(self):
        """A length that does not fit in 64 bits fails at serialization.

        Constructing the header is accepted; only ``bytes()`` raises
        ``ValueError``.
        """
        h = websockets_frame.FrameHeader(
            payload_length=2 ** 64 + 1,
        )
        with pytest.raises(ValueError):
            bytes(h)

    @pytest.mark.parametrize("raw,expected", [
        ('0100', 0),
        ('017D', 125),
        ('017E007E', 126),
        ('017E007F', 127),
        ('017E008E', 142),
        ('017EFFFE', 65534),
        ('017EFFFF', 65535),
        ('017F0000000000010000', 65536),
        ('017F00000001FFFFFFFF', 8589934591),
        ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
    ])
    def test_deserialization_length(self, raw, expected):
        """Parsing each serialized hex form recovers its payload length."""
        h = websockets_frame.FrameHeader.from_file(
            tutils.treader(codecs.decode(raw, 'hex'))
        )
        assert h.payload_length == expected

    @pytest.mark.parametrize("raw,expected", [
        ('0100', (False, None)),
        ('018000000000', (True, '00000000')),
        ('018012345678', (True, '12345678')),
    ])
    def test_deserialization_masking(self, raw, expected):
        """The mask flag and the masking key are parsed from the header.

        ``expected`` is ``(mask, masking_key_hex)``. In the masked cases the
        second byte is 0x80 and the four bytes that follow are the key.
        """
        expected_mask, expected_key = expected
        h = websockets_frame.FrameHeader.from_file(
            tutils.treader(codecs.decode(raw, 'hex'))
        )
        assert h.mask == expected_mask
        # The key is only meaningful (and only checked) for masked headers.
        if h.mask:
            assert h.masking_key == codecs.decode(expected_key, 'hex')

    def test_equality(self):
        """Headers compare equal field-by-field and never equal a non-header."""
        h = websockets_frame.FrameHeader(mask=True, masking_key=b'1234')
        h2 = websockets_frame.FrameHeader(mask=True, masking_key=b'1234')
        assert h == h2

        h = websockets_frame.FrameHeader(fin=True)
        h2 = websockets_frame.FrameHeader(fin=False)
        assert h != h2
        assert h != 'foobar'

    def test_roundtrip(self):
        """Serializing a header and parsing it back yields an equal header."""
        # Named so it does not shadow the builtin ``round()``.
        def assert_roundtrip(*args, **kwargs):
            h = websockets_frame.FrameHeader(*args, **kwargs)
            h2 = websockets_frame.FrameHeader.from_file(tutils.treader(bytes(h)))
            assert h == h2

        assert_roundtrip()
        assert_roundtrip(fin=True)
        assert_roundtrip(rsv1=True)
        assert_roundtrip(rsv2=True)
        assert_roundtrip(rsv3=True)
        assert_roundtrip(payload_length=1)
        assert_roundtrip(payload_length=100)
        assert_roundtrip(payload_length=1000)
        assert_roundtrip(payload_length=10000)
        assert_roundtrip(opcode=Opcode.PING)
        assert_roundtrip(masking_key=b"test")

    def test_human_readable(self):
        """repr() yields a non-empty string for populated and default headers."""
        f = websockets_frame.FrameHeader(
            masking_key=b"test",
            fin=True,
            payload_length=10
        )
        assert repr(f)
        f = websockets_frame.FrameHeader()
        assert repr(f)

    def test_funky(self):
        """A masking key given with ``mask=False`` does not produce a masked header."""
        f = websockets_frame.FrameHeader(masking_key=b"test", mask=False)
        raw = bytes(f)
        f2 = websockets_frame.FrameHeader.from_file(tutils.treader(raw))
        assert not f2.mask

    def test_violations(self):
        """Invalid constructor arguments raise with a descriptive message."""
        # An out-of-range opcode is rejected.
        with pytest.raises(Exception, match="opcode"):
            websockets_frame.FrameHeader(opcode=17)

        # A 1-byte masking key is rejected; the valid keys elsewhere in this
        # class are 4 bytes long.
        with pytest.raises(Exception, match="Masking key"):
            websockets_frame.FrameHeader(masking_key=b"x")

    def test_automask(self):
        """``mask`` and ``masking_key`` fill each other in unless overridden."""
        # mask=True alone produces a (non-empty) masking key.
        f = websockets_frame.FrameHeader(mask=True)
        assert f.masking_key

        # A masking key alone switches masking on.
        f = websockets_frame.FrameHeader(masking_key=b"foob")
        assert f.mask

        # An explicit falsy mask wins over the key, and the key is dropped.
        f = websockets_frame.FrameHeader(masking_key=b"foob", mask=0)
        assert not f.mask
        assert not f.masking_key


class TestFrame:
    """Tests for ``websockets_frame.Frame``: equality, round-trips and repr."""

    def test_equality(self):
        """Frames with equal payloads compare equal; raw bytes never do."""
        f = websockets_frame.Frame(payload=b'1234')
        f2 = websockets_frame.Frame(payload=b'1234')
        assert f == f2

        assert f != b'1234'

    def test_roundtrip(self):
        """Serializing a frame and parsing it back yields an equal frame."""
        # Named so it does not shadow the builtin ``round()``.
        def assert_roundtrip(*args, **kwargs):
            f = websockets_frame.Frame(*args, **kwargs)
            raw = bytes(f)
            f2 = websockets_frame.Frame.from_file(tutils.treader(raw))
            assert f == f2

        assert_roundtrip(b"test")
        assert_roundtrip(b"test", fin=1)
        assert_roundtrip(b"test", rsv1=1)
        assert_roundtrip(b"test", opcode=Opcode.PING)
        assert_roundtrip(b"test", masking_key=b"test")

    def test_human_readable(self):
        """repr() works for an empty frame and shows a short payload."""
        f = websockets_frame.Frame()
        assert repr(f)

        f = websockets_frame.Frame(b"foobar")
        assert "foobar" in repr(f)

    @pytest.mark.parametrize("masked", [True, False])
    @pytest.mark.parametrize("length", [100, 50000, 150000])
    def test_serialization_bijection(self, masked, length):
        """Random payloads survive ``bytes()`` -> ``Frame.from_bytes()`` unchanged.

        The lengths are 100 (below 126), 50000 (at most 65535) and 150000
        (larger). Each falls in a different payload-length range.
        """
        frame = websockets_frame.Frame(
            os.urandom(length),
            fin=True,
            opcode=Opcode.TEXT,
            mask=int(masked),
            masking_key=(os.urandom(4) if masked else None)
        )
        serialized = bytes(frame)
        assert frame == websockets_frame.Frame.from_bytes(serialized)