"""Tests for pathod's websocket frame model (masking, headers, whole frames)."""
import os
import codecs

import pytest
from wsproto.frame_protocol import Opcode

from pathod.language import websockets_frame
from mitmproxy.test import tutils


class TestMasker:
    """Masker XORs data with a repeating 4-byte key.

    The expected vectors show the key offset carries over between calls
    (e.g. the second chunk starts at key byte 1 after a 1-byte first chunk),
    and that masking twice with the same key restores the input.
    """

    @pytest.mark.parametrize("chunks,expected", [
        ([b"a"], '00'),
        ([b"four"], '070d1616'),
        ([b"fourf"], '070d161607'),
        ([b"fourfive"], '070d1616070b1501'),
        ([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'),
        ([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'),  # noqa
    ])
    def test_masker(self, chunks, expected):
        m = websockets_frame.Masker(b"abcd")
        data = b"".join([m(t) for t in chunks])
        assert data == codecs.decode(expected, 'hex')

        # A fresh masker with the same key undoes the masking.
        data = websockets_frame.Masker(b"abcd")(data)
        assert data == b"".join(chunks)


class TestFrameHeader:
    """Serialization, parsing and validation of websocket frame headers."""

    @pytest.mark.parametrize("length,expected", [
        (0, '0100'),
        (125, '017D'),
        (126, '017E007E'),
        (127, '017E007F'),
        (142, '017E008E'),
        (65534, '017EFFFE'),
        (65535, '017EFFFF'),
        (65536, '017F0000000000010000'),
        (8589934591, '017F00000001FFFFFFFF'),
        (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
    ])
    def test_serialization_length(self, length, expected):
        """Lengths <= 125 fit inline; larger ones use the 126 (2-byte) or 127 (8-byte) marker."""
        h = websockets_frame.FrameHeader(
            opcode=Opcode.TEXT,
            payload_length=length,
        )
        assert bytes(h) == codecs.decode(expected, 'hex')

    def test_serialization_too_large(self):
        """A payload length beyond 64 bits cannot be encoded."""
        h = websockets_frame.FrameHeader(
            payload_length=2 ** 64 + 1,
        )
        with pytest.raises(ValueError):
            bytes(h)

    @pytest.mark.parametrize("raw_hex,expected", [
        ('0100', 0),
        ('017D', 125),
        ('017E007E', 126),
        ('017E007F', 127),
        ('017E008E', 142),
        ('017EFFFE', 65534),
        ('017EFFFF', 65535),
        ('017F0000000000010000', 65536),
        ('017F00000001FFFFFFFF', 8589934591),
        ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
    ])
    def test_deserialization_length(self, raw_hex, expected):
        """Parsing recovers the payload length from each length encoding."""
        h = websockets_frame.FrameHeader.from_file(tutils.treader(codecs.decode(raw_hex, 'hex')))
        assert h.payload_length == expected

    @pytest.mark.parametrize("raw_hex,expected", [
        ('0100', (False, None)),
        ('018000000000', (True, '00000000')),
        ('018012345678', (True, '12345678')),
    ])
    def test_deserialization_masking(self, raw_hex, expected):
        """The 0x80 bit of the second byte flags a mask; the 4-byte key follows."""
        h = websockets_frame.FrameHeader.from_file(tutils.treader(codecs.decode(raw_hex, 'hex')))
        assert h.mask == expected[0]
        if h.mask:
            assert h.masking_key == codecs.decode(expected[1], 'hex')

    def test_equality(self):
        """Headers compare by field values and never equal unrelated objects."""
        h = websockets_frame.FrameHeader(mask=True, masking_key=b'1234')
        h2 = websockets_frame.FrameHeader(mask=True, masking_key=b'1234')
        assert h == h2

        h = websockets_frame.FrameHeader(fin=True)
        h2 = websockets_frame.FrameHeader(fin=False)
        assert h != h2

        assert h != 'foobar'

    def test_roundtrip(self):
        """Serializing then parsing a header yields an equal header."""
        def check_roundtrip(*args, **kwargs):
            h = websockets_frame.FrameHeader(*args, **kwargs)
            h2 = websockets_frame.FrameHeader.from_file(tutils.treader(bytes(h)))
            assert h == h2

        check_roundtrip()
        check_roundtrip(fin=True)
        check_roundtrip(rsv1=True)
        check_roundtrip(rsv2=True)
        check_roundtrip(rsv3=True)
        check_roundtrip(payload_length=1)
        check_roundtrip(payload_length=100)
        check_roundtrip(payload_length=1000)
        check_roundtrip(payload_length=10000)
        check_roundtrip(opcode=Opcode.PING)
        check_roundtrip(masking_key=b"test")

    def test_human_readable(self):
        """repr() produces a non-empty string for populated and default headers."""
        f = websockets_frame.FrameHeader(
            masking_key=b"test",
            fin=True,
            payload_length=10
        )
        assert repr(f)

        f = websockets_frame.FrameHeader()
        assert repr(f)

    def test_funky(self):
        """An explicit mask=False wins over a supplied masking key on the wire."""
        f = websockets_frame.FrameHeader(masking_key=b"test", mask=False)
        raw = bytes(f)
        f2 = websockets_frame.FrameHeader.from_file(tutils.treader(raw))
        assert not f2.mask

    def test_violations(self):
        """Out-of-range opcodes and wrong-length masking keys are rejected."""
        with pytest.raises(Exception, match="opcode"):
            websockets_frame.FrameHeader(opcode=17)
        with pytest.raises(Exception, match="Masking key"):
            websockets_frame.FrameHeader(masking_key=b"x")

    def test_automask(self):
        """mask and masking_key imply each other unless mask is explicitly falsy."""
        f = websockets_frame.FrameHeader(mask=True)
        assert f.masking_key

        f = websockets_frame.FrameHeader(masking_key=b"foob")
        assert f.mask

        f = websockets_frame.FrameHeader(masking_key=b"foob", mask=0)
        assert not f.mask
        assert not f.masking_key


class TestFrame:
    """Whole-frame (header + payload) equality, round-trips and repr."""

    def test_equality(self):
        """Frames with equal payloads compare equal; a frame never equals raw bytes."""
        f = websockets_frame.Frame(payload=b'1234')
        f2 = websockets_frame.Frame(payload=b'1234')
        assert f == f2

        assert f != b'1234'

    def test_roundtrip(self):
        """Serializing then parsing a frame yields an equal frame."""
        def check_roundtrip(*args, **kwargs):
            f = websockets_frame.Frame(*args, **kwargs)
            raw = bytes(f)
            f2 = websockets_frame.Frame.from_file(tutils.treader(raw))
            assert f == f2

        check_roundtrip(b"test")
        check_roundtrip(b"test", fin=1)
        check_roundtrip(b"test", rsv1=1)
        check_roundtrip(b"test", opcode=Opcode.PING)
        check_roundtrip(b"test", masking_key=b"test")

    def test_human_readable(self):
        """repr() is non-empty and includes the payload text."""
        f = websockets_frame.Frame()
        assert repr(f)

        f = websockets_frame.Frame(b"foobar")
        assert "foobar" in repr(f)

    @pytest.mark.parametrize("masked", [True, False])
    @pytest.mark.parametrize("length", [100, 50000, 150000])
    def test_serialization_bijection(self, masked, length):
        """Random payloads of each length encoding survive bytes() -> from_bytes()."""
        frame = websockets_frame.Frame(
            os.urandom(length),
            fin=True,
            opcode=Opcode.TEXT,
            mask=int(masked),
            masking_key=(os.urandom(4) if masked else None)
        )
        serialized = bytes(frame)
        assert frame == websockets_frame.Frame.from_bytes(serialized)