# Mirror of https://github.com/Grasscutters/mitmproxy.git
# (synced 2024-11-29 19:08:44 +00:00) -- Python, 188 lines, 5.9 KiB.
import os
|
|
import codecs
|
|
import pytest
|
|
|
|
from wsproto.frame_protocol import Opcode
|
|
|
|
from pathod.language import websockets_frame
|
|
from mitmproxy.test import tutils
|
|
|
|
|
|
class TestMasker:
    """Tests for ``websockets_frame.Masker``, the payload masking helper."""

    @pytest.mark.parametrize("chunks,expected", [
        ([b"a"], '00'),
        ([b"four"], '070d1616'),
        ([b"fourf"], '070d161607'),
        ([b"fourfive"], '070d1616070b1501'),
        ([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'),
        ([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'),  # noqa
    ])
    def test_masker(self, chunks, expected):
        """Mask ``chunks`` with key ``b"abcd"`` and check the result and its inverse.

        ``chunks`` is a list of byte strings fed one by one to the same
        Masker; ``expected`` is the hex encoding of the concatenated output.
        """
        # A single Masker instance is reused across chunks: the multi-chunk
        # vectors show the key position carries over between calls (the
        # second chunk of the third case starts with 0x03 == ord('a') ^ ord('b')).
        m = websockets_frame.Masker(b"abcd")
        data = b"".join([m(t) for t in chunks])
        assert data == codecs.decode(expected, 'hex')

        # The vectors are byte-wise XOR with the key, so masking the output
        # again with a fresh Masker using the same key restores the input.
        data = websockets_frame.Masker(b"abcd")(data)
        assert data == b"".join(chunks)
|
|
|
|
|
|
class TestFrameHeader:
    """Tests for ``websockets_frame.FrameHeader`` serialization, parsing and validation."""

    @pytest.mark.parametrize("length,expected", [
        (0, '0100'),
        (125, '017D'),
        (126, '017E007E'),
        (127, '017E007F'),
        (142, '017E008E'),
        (65534, '017EFFFE'),
        (65535, '017EFFFF'),
        (65536, '017F0000000000010000'),
        (8589934591, '017F00000001FFFFFFFF'),
        (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
    ])
    def test_serialization_length(self, length, expected):
        """Serialize a TEXT header with ``length`` and compare to the ``expected`` hex bytes.

        The vectors cover each length encoding seen in the output: a plain
        7-bit value (up to 125), the 0x7E marker plus 2 bytes, and the 0x7F
        marker plus 8 bytes.
        """
        h = websockets_frame.FrameHeader(
            opcode=Opcode.TEXT,
            payload_length=length,
        )
        assert bytes(h) == codecs.decode(expected, 'hex')

    def test_serialization_too_large(self):
        """A payload length beyond the largest encodable value must not serialize."""
        # 2 ** 64 - 1 is the largest length exercised above (all 8 extended
        # length bytes set to 0xFF); anything larger cannot be represented.
        h = websockets_frame.FrameHeader(
            payload_length=2 ** 64 + 1,
        )
        with pytest.raises(ValueError):
            bytes(h)

    @pytest.mark.parametrize("raw,expected", [
        ('0100', 0),
        ('017D', 125),
        ('017E007E', 126),
        ('017E007F', 127),
        ('017E008E', 142),
        ('017EFFFE', 65534),
        ('017EFFFF', 65535),
        ('017F0000000000010000', 65536),
        ('017F00000001FFFFFFFF', 8589934591),
        ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
    ])
    def test_deserialization_length(self, raw, expected):
        """Parse the hex-encoded header ``raw`` and check its payload length."""
        h = websockets_frame.FrameHeader.from_file(tutils.treader(codecs.decode(raw, 'hex')))
        assert h.payload_length == expected

    @pytest.mark.parametrize("raw,expected", [
        ('0100', (False, None)),
        ('018000000000', (True, '00000000')),
        ('018012345678', (True, '12345678')),
    ])
    def test_deserialization_masking(self, raw, expected):
        """Parse ``raw`` and check the mask flag and, when masked, the masking key.

        ``expected`` is a ``(mask, masking_key_hex)`` pair; the key is only
        checked when the parsed header reports a mask.
        """
        masked, key_hex = expected
        h = websockets_frame.FrameHeader.from_file(tutils.treader(codecs.decode(raw, 'hex')))
        assert h.mask == masked
        if h.mask:
            assert h.masking_key == codecs.decode(key_hex, 'hex')

    def test_equality(self):
        """Headers compare equal by field values and unequal to other types."""
        h = websockets_frame.FrameHeader(mask=True, masking_key=b'1234')
        h2 = websockets_frame.FrameHeader(mask=True, masking_key=b'1234')
        assert h == h2

        h = websockets_frame.FrameHeader(fin=True)
        h2 = websockets_frame.FrameHeader(fin=False)
        assert h != h2

        assert h != 'foobar'

    def test_roundtrip(self):
        """Serializing a header and parsing it back yields an equal header."""
        def check_roundtrip(*args, **kwargs):
            # Build a header, serialize it, parse the bytes back and compare.
            h = websockets_frame.FrameHeader(*args, **kwargs)
            h2 = websockets_frame.FrameHeader.from_file(tutils.treader(bytes(h)))
            assert h == h2

        check_roundtrip()
        check_roundtrip(fin=True)
        check_roundtrip(rsv1=True)
        check_roundtrip(rsv2=True)
        check_roundtrip(rsv3=True)
        check_roundtrip(payload_length=1)
        check_roundtrip(payload_length=100)
        check_roundtrip(payload_length=1000)
        check_roundtrip(payload_length=10000)
        check_roundtrip(opcode=Opcode.PING)
        check_roundtrip(masking_key=b"test")

    def test_human_readable(self):
        """``repr`` produces non-empty output for populated and default headers."""
        f = websockets_frame.FrameHeader(
            masking_key=b"test",
            fin=True,
            payload_length=10
        )
        assert repr(f)

        f = websockets_frame.FrameHeader()
        assert repr(f)

    def test_funky(self):
        """An explicit ``mask=False`` wins over a supplied masking key."""
        # The key is given but masking is explicitly disabled, so the header
        # parsed back from the serialized bytes must report no mask.
        f = websockets_frame.FrameHeader(masking_key=b"test", mask=False)
        raw = bytes(f)
        f2 = websockets_frame.FrameHeader.from_file(tutils.treader(raw))
        assert not f2.mask

    def test_violations(self):
        """Invalid constructor arguments are rejected with descriptive errors."""
        # opcode=17 is outside the range that fits in the header's opcode field.
        with pytest.raises(Exception, match="opcode"):
            websockets_frame.FrameHeader(opcode=17)
        # A one-byte key is rejected; the valid keys used in this file are 4 bytes.
        with pytest.raises(Exception, match="Masking key"):
            websockets_frame.FrameHeader(masking_key=b"x")

    def test_automask(self):
        """``mask`` and ``masking_key`` are filled in from each other unless overridden."""
        # mask=True alone generates a masking key.
        f = websockets_frame.FrameHeader(mask=True)
        assert f.masking_key

        # A masking key alone turns the mask flag on.
        f = websockets_frame.FrameHeader(masking_key=b"foob")
        assert f.mask

        # An explicit falsy mask clears both the flag and the key.
        f = websockets_frame.FrameHeader(masking_key=b"foob", mask=0)
        assert not f.mask
        assert not f.masking_key
|
|
|
|
|
|
class TestFrame:
    """Tests for ``websockets_frame.Frame`` (header plus payload)."""

    def test_equality(self):
        """Frames compare equal by content and unequal to raw bytes."""
        f = websockets_frame.Frame(payload=b'1234')
        f2 = websockets_frame.Frame(payload=b'1234')
        assert f == f2

        assert f != b'1234'

    def test_roundtrip(self):
        """Serializing a frame and parsing it back yields an equal frame."""
        def check_roundtrip(*args, **kwargs):
            # Build a frame, serialize it, parse the bytes back and compare.
            f = websockets_frame.Frame(*args, **kwargs)
            raw = bytes(f)
            f2 = websockets_frame.Frame.from_file(tutils.treader(raw))
            assert f == f2

        check_roundtrip(b"test")
        check_roundtrip(b"test", fin=1)
        check_roundtrip(b"test", rsv1=1)
        check_roundtrip(b"test", opcode=Opcode.PING)
        check_roundtrip(b"test", masking_key=b"test")

    def test_human_readable(self):
        """``repr`` is non-empty and includes the payload text."""
        f = websockets_frame.Frame()
        assert repr(f)

        f = websockets_frame.Frame(b"foobar")
        assert "foobar" in repr(f)

    @pytest.mark.parametrize("masked", [True, False])
    @pytest.mark.parametrize("length", [100, 50000, 150000])
    def test_serialization_bijection(self, masked, length):
        """``from_bytes(bytes(frame))`` reproduces the frame for random payloads.

        The lengths match the header length-encoding tiers exercised in
        TestFrameHeader: at most 125 bytes, at most 65535 bytes, and beyond
        65535 bytes. Each is run with and without a random 4-byte masking key.
        """
        frame = websockets_frame.Frame(
            os.urandom(length),
            fin=True,
            opcode=Opcode.TEXT,
            mask=int(masked),
            masking_key=(os.urandom(4) if masked else None)
        )
        serialized = bytes(frame)
        assert frame == websockets_frame.Frame.from_bytes(serialized)
|