mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
websocket: interface refactoring
- Separate out FrameHeader. We need to deal with this separately in many circumstances. - Simpler equality scheme. - Bits are now specified by truthiness - we don't care about the integer value. This means lots of validation is not needed any more.
This commit is contained in:
parent
3519871f34
commit
f22bc0b4c7
@ -49,3 +49,19 @@ def hexdump(s):
|
||||
(o, x, cleanBin(part, True))
|
||||
)
|
||||
return parts
|
||||
|
||||
|
||||
def setbit(byte, offset, value):
|
||||
"""
|
||||
Set a bit in a byte to 1 if value is truthy, 0 if not.
|
||||
"""
|
||||
if value:
|
||||
return byte | (1 << offset)
|
||||
else:
|
||||
return byte & ~(1 << offset)
|
||||
|
||||
|
||||
def getbit(byte, offset):
|
||||
mask = 1 << offset
|
||||
if byte & mask:
|
||||
return True
|
||||
|
@ -1,5 +1,4 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
@ -83,23 +82,6 @@ def server_handshake_headers(key):
|
||||
)
|
||||
|
||||
|
||||
def get_payload_length_pair(payload_bytestring):
|
||||
"""
|
||||
A websockets frame contains an initial length_code, and an optional
|
||||
extended length code to represent the actual length if length code is
|
||||
larger than 125
|
||||
"""
|
||||
actual_length = len(payload_bytestring)
|
||||
|
||||
if actual_length <= 125:
|
||||
length_code = actual_length
|
||||
elif actual_length >= 126 and actual_length <= 65535:
|
||||
length_code = 126
|
||||
else:
|
||||
length_code = 127
|
||||
return (length_code, actual_length)
|
||||
|
||||
|
||||
def make_length_code(len):
|
||||
"""
|
||||
A websockets frame contains an initial length_code, and an optional
|
||||
@ -132,40 +114,113 @@ def create_server_nonce(client_nonce):
|
||||
)
|
||||
|
||||
|
||||
def frame_header_bytes(
|
||||
opcode = 0,
|
||||
payload_length = 0,
|
||||
fin = 0,
|
||||
rsv1 = 0,
|
||||
rsv2 = 0,
|
||||
rsv3 = 0,
|
||||
mask = 0,
|
||||
masking_key = None,
|
||||
length_code = None
|
||||
):
|
||||
first_byte = (fin << 7) | (rsv1 << 6) |\
|
||||
(rsv2 << 4) | (rsv3 << 4) | opcode
|
||||
DEFAULT = object()
|
||||
class FrameHeader:
|
||||
def __init__(
|
||||
self,
|
||||
opcode = OPCODE.TEXT,
|
||||
payload_length = 0,
|
||||
fin = False,
|
||||
rsv1 = False,
|
||||
rsv2 = False,
|
||||
rsv3 = False,
|
||||
masking_key = None,
|
||||
mask = DEFAULT,
|
||||
length_code = DEFAULT
|
||||
):
|
||||
self.opcode = opcode
|
||||
self.payload_length = payload_length
|
||||
self.fin = fin
|
||||
self.rsv1 = rsv1
|
||||
self.rsv2 = rsv2
|
||||
self.rsv3 = rsv3
|
||||
self.mask = mask
|
||||
self.masking_key = masking_key
|
||||
self.length_code = length_code
|
||||
|
||||
if length_code is None:
|
||||
length_code = make_length_code(payload_length)
|
||||
def to_bytes(self):
|
||||
first_byte = utils.setbit(0, 7, self.fin)
|
||||
first_byte = utils.setbit(first_byte, 6, self.rsv1)
|
||||
first_byte = utils.setbit(first_byte, 5, self.rsv2)
|
||||
first_byte = utils.setbit(first_byte, 4, self.rsv3)
|
||||
first_byte = first_byte | self.opcode
|
||||
|
||||
second_byte = (mask << 7) | length_code
|
||||
if self.length_code is DEFAULT:
|
||||
length_code = make_length_code(self.payload_length)
|
||||
else:
|
||||
length_code = self.length_code
|
||||
|
||||
b = chr(first_byte) + chr(second_byte)
|
||||
if self.mask is DEFAULT:
|
||||
mask = bool(self.masking_key)
|
||||
else:
|
||||
mask = self.mask
|
||||
|
||||
if payload_length < 126:
|
||||
pass
|
||||
elif payload_length < MAX_16_BIT_INT:
|
||||
# '!H' pack as 16 bit unsigned short
|
||||
# add 2 byte extended payload length
|
||||
b += struct.pack('!H', payload_length)
|
||||
elif payload_length < MAX_64_BIT_INT:
|
||||
# '!Q' = pack as 64 bit unsigned long long
|
||||
# add 8 bytes extended payload length
|
||||
b += struct.pack('!Q', payload_length)
|
||||
if masking_key is not None:
|
||||
b += masking_key
|
||||
return b
|
||||
second_byte = (mask << 7) | length_code
|
||||
|
||||
b = chr(first_byte) + chr(second_byte)
|
||||
|
||||
if self.payload_length < 126:
|
||||
pass
|
||||
elif self.payload_length < MAX_16_BIT_INT:
|
||||
# '!H' pack as 16 bit unsigned short
|
||||
# add 2 byte extended payload length
|
||||
b += struct.pack('!H', self.payload_length)
|
||||
elif self.payload_length < MAX_64_BIT_INT:
|
||||
# '!Q' = pack as 64 bit unsigned long long
|
||||
# add 8 bytes extended payload length
|
||||
b += struct.pack('!Q', self.payload_length)
|
||||
if self.masking_key is not None:
|
||||
b += self.masking_key
|
||||
return b
|
||||
|
||||
@classmethod
|
||||
def from_file(klass, fp):
|
||||
"""
|
||||
read a websockets frame header
|
||||
"""
|
||||
first_byte = utils.bytes_to_int(fp.read(1))
|
||||
second_byte = utils.bytes_to_int(fp.read(1))
|
||||
|
||||
fin = utils.getbit(first_byte, 7)
|
||||
rsv1 = utils.getbit(first_byte, 6)
|
||||
rsv2 = utils.getbit(first_byte, 5)
|
||||
rsv3 = utils.getbit(first_byte, 4)
|
||||
# grab right most 4 bits by and-ing with 00001111
|
||||
opcode = first_byte & 15
|
||||
# grab left most bit
|
||||
mask_bit = second_byte >> 7
|
||||
# grab the next 7 bits
|
||||
length_code = second_byte & 127
|
||||
|
||||
# payload_lengthy > 125 indicates you need to read more bytes
|
||||
# to get the actual payload length
|
||||
if length_code <= 125:
|
||||
payload_length = length_code
|
||||
elif length_code == 126:
|
||||
payload_length = utils.bytes_to_int(fp.read(2))
|
||||
elif length_code == 127:
|
||||
payload_length = utils.bytes_to_int(fp.read(8))
|
||||
|
||||
# masking key only present if mask bit set
|
||||
if mask_bit == 1:
|
||||
masking_key = fp.read(4)
|
||||
else:
|
||||
masking_key = None
|
||||
|
||||
return klass(
|
||||
fin = fin,
|
||||
rsv1 = rsv1,
|
||||
rsv2 = rsv2,
|
||||
rsv3 = rsv3,
|
||||
opcode = opcode,
|
||||
mask = mask_bit,
|
||||
length_code = length_code,
|
||||
payload_length = payload_length,
|
||||
masking_key = masking_key,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.to_bytes() == other.to_bytes()
|
||||
|
||||
|
||||
class Frame(object):
|
||||
@ -194,27 +249,10 @@ class Frame(object):
|
||||
| Payload Data continued ... |
|
||||
+---------------------------------------------------------------+
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
fin, # decmial integer 1 or 0
|
||||
opcode, # decmial integer 1 - 4
|
||||
payload = "", # bytestring
|
||||
masking_key = None, # 32 bit byte string
|
||||
mask_bit = 0, # decimal integer 1 or 0
|
||||
payload_length_code = None, # decimal integer 1 - 127
|
||||
rsv1 = 0, # decimal integer 1 or 0
|
||||
rsv2 = 0, # decimal integer 1 or 0
|
||||
rsv3 = 0, # decimal integer 1 or 0
|
||||
):
|
||||
self.fin = fin
|
||||
self.rsv1 = rsv1
|
||||
self.rsv2 = rsv2
|
||||
self.rsv3 = rsv3
|
||||
self.opcode = opcode
|
||||
self.mask_bit = mask_bit
|
||||
self.payload_length_code = payload_length_code
|
||||
self.masking_key = masking_key
|
||||
def __init__(self, payload = "", **kwargs):
|
||||
self.payload = payload
|
||||
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
|
||||
self.header = FrameHeader(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def default(cls, message, from_client = False):
|
||||
@ -230,10 +268,10 @@ class Frame(object):
|
||||
masking_key = None
|
||||
|
||||
return cls(
|
||||
message,
|
||||
fin = 1, # final frame
|
||||
opcode = OPCODE.TEXT, # text
|
||||
mask_bit = mask_bit,
|
||||
payload = message,
|
||||
mask = mask_bit,
|
||||
masking_key = masking_key,
|
||||
)
|
||||
|
||||
@ -243,30 +281,30 @@ class Frame(object):
|
||||
Frame has not been corrupted.
|
||||
"""
|
||||
constraints = [
|
||||
0 <= self.fin <= 1,
|
||||
0 <= self.rsv1 <= 1,
|
||||
0 <= self.rsv2 <= 1,
|
||||
0 <= self.rsv3 <= 1,
|
||||
1 <= self.opcode <= 4,
|
||||
0 <= self.mask_bit <= 1,
|
||||
0 <= self.header.fin <= 1,
|
||||
0 <= self.header.rsv1 <= 1,
|
||||
0 <= self.header.rsv2 <= 1,
|
||||
0 <= self.header.rsv3 <= 1,
|
||||
1 <= self.header.opcode <= 4,
|
||||
0 <= self.header.mask <= 1,
|
||||
#1 <= self.payload_length_code <= 127,
|
||||
1 <= len(self.masking_key) <= 4 if self.mask_bit else True,
|
||||
self.masking_key is not None if self.mask_bit else True
|
||||
1 <= len(self.header.masking_key) <= 4 if self.header.mask else True,
|
||||
self.header.masking_key is not None if self.header.mask else True
|
||||
]
|
||||
if not all(constraints):
|
||||
return False
|
||||
return True
|
||||
|
||||
def human_readable(self): # pragma: nocover
|
||||
def human_readable(self):
|
||||
return "\n".join([
|
||||
("fin - " + str(self.fin)),
|
||||
("rsv1 - " + str(self.rsv1)),
|
||||
("rsv2 - " + str(self.rsv2)),
|
||||
("rsv3 - " + str(self.rsv3)),
|
||||
("opcode - " + str(self.opcode)),
|
||||
("mask_bit - " + str(self.mask_bit)),
|
||||
("payload_length_code - " + str(self.payload_length_code)),
|
||||
("masking_key - " + repr(str(self.masking_key))),
|
||||
("fin - " + str(self.header.fin)),
|
||||
("rsv1 - " + str(self.header.rsv1)),
|
||||
("rsv2 - " + str(self.header.rsv2)),
|
||||
("rsv3 - " + str(self.header.rsv3)),
|
||||
("opcode - " + str(self.header.opcode)),
|
||||
("mask - " + str(self.header.mask)),
|
||||
("length_code - " + str(self.header.length_code)),
|
||||
("masking_key - " + repr(str(self.header.masking_key))),
|
||||
("payload - " + repr(str(self.payload))),
|
||||
])
|
||||
|
||||
@ -284,18 +322,9 @@ class Frame(object):
|
||||
If you haven't checked is_valid_frame() then there's no guarentees
|
||||
that the serialized bytes will be correct. see safe_to_bytes()
|
||||
"""
|
||||
b = frame_header_bytes(
|
||||
opcode = self.opcode,
|
||||
fin = self.fin,
|
||||
rsv1 = self.rsv1,
|
||||
rsv2 = self.rsv2,
|
||||
rsv3 = self.rsv3,
|
||||
mask = self.mask_bit,
|
||||
masking_key = self.masking_key,
|
||||
payload_length = len(self.payload) if self.payload else 0
|
||||
)
|
||||
if self.masking_key:
|
||||
b += apply_mask(self.payload, self.masking_key)
|
||||
b = self.header.to_bytes()
|
||||
if self.header.masking_key:
|
||||
b += apply_mask(self.payload, self.header.masking_key)
|
||||
else:
|
||||
b += self.payload
|
||||
return b
|
||||
@ -312,66 +341,20 @@ class Frame(object):
|
||||
fp is a "file like" object that could be backed by a network
|
||||
stream or a disk or an in memory stream reader
|
||||
"""
|
||||
first_byte = utils.bytes_to_int(fp.read(1))
|
||||
second_byte = utils.bytes_to_int(fp.read(1))
|
||||
header = FrameHeader.from_file(fp)
|
||||
payload = fp.read(header.payload_length)
|
||||
|
||||
# grab the left most bit
|
||||
fin = first_byte >> 7
|
||||
# grab right most 4 bits by and-ing with 00001111
|
||||
opcode = first_byte & 15
|
||||
# grab left most bit
|
||||
mask_bit = second_byte >> 7
|
||||
# grab the next 7 bits
|
||||
payload_length = second_byte & 127
|
||||
|
||||
# payload_lengthy > 125 indicates you need to read more bytes
|
||||
# to get the actual payload length
|
||||
if payload_length <= 125:
|
||||
actual_payload_length = payload_length
|
||||
|
||||
elif payload_length == 126:
|
||||
actual_payload_length = utils.bytes_to_int(fp.read(2))
|
||||
|
||||
elif payload_length == 127:
|
||||
actual_payload_length = utils.bytes_to_int(fp.read(8))
|
||||
|
||||
# masking key only present if mask bit set
|
||||
if mask_bit == 1:
|
||||
masking_key = fp.read(4)
|
||||
else:
|
||||
masking_key = None
|
||||
|
||||
payload = fp.read(actual_payload_length)
|
||||
|
||||
if mask_bit == 1 and masking_key:
|
||||
payload = apply_mask(payload, masking_key)
|
||||
if header.mask == 1 and header.masking_key:
|
||||
payload = apply_mask(payload, header.masking_key)
|
||||
|
||||
return cls(
|
||||
fin = fin,
|
||||
opcode = opcode,
|
||||
mask_bit = mask_bit,
|
||||
payload_length_code = payload_length,
|
||||
payload = payload,
|
||||
masking_key = masking_key,
|
||||
payload,
|
||||
fin = header.fin,
|
||||
opcode = header.opcode,
|
||||
mask = header.mask,
|
||||
payload_length = header.payload_length,
|
||||
masking_key = header.masking_key,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.payload_length_code is None:
|
||||
myplc = make_length_code(len(self.payload))
|
||||
else:
|
||||
myplc = self.payload_length_code
|
||||
if other.payload_length_code is None:
|
||||
otherplc = make_length_code(len(other.payload))
|
||||
else:
|
||||
otherplc = other.payload_length_code
|
||||
return (
|
||||
self.fin == other.fin and
|
||||
self.rsv1 == other.rsv1 and
|
||||
self.rsv2 == other.rsv2 and
|
||||
self.rsv3 == other.rsv3 and
|
||||
self.opcode == other.opcode and
|
||||
self.mask_bit == other.mask_bit and
|
||||
self.masking_key == other.masking_key and
|
||||
self.payload == other.payload,
|
||||
myplc == otherplc
|
||||
)
|
||||
return self.to_bytes() == other.to_bytes()
|
||||
|
@ -1,10 +1,9 @@
|
||||
from netlib import tcp, test, websockets, http
|
||||
import cStringIO
|
||||
import os
|
||||
|
||||
from nose.tools import raises
|
||||
|
||||
|
||||
def test_frame_header_bytes():
|
||||
assert websockets.frame_header_bytes()
|
||||
from netlib import tcp, test, websockets, http
|
||||
|
||||
|
||||
class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
@ -119,12 +118,12 @@ class TestWebSockets(test.ServerTestBase):
|
||||
assert frame.is_valid()
|
||||
|
||||
frame = f()
|
||||
frame.fin = 2
|
||||
frame.header.fin = 2
|
||||
assert not frame.is_valid()
|
||||
|
||||
frame = f()
|
||||
frame.mask_bit = 1
|
||||
frame.masking_key = "foobbarboo"
|
||||
frame.header.mask_bit = 1
|
||||
frame.header.masking_key = "foobbarboo"
|
||||
assert not frame.is_valid()
|
||||
|
||||
def test_serialization_bijection(self):
|
||||
@ -181,3 +180,39 @@ class TestBadHandshake(test.ServerTestBase):
|
||||
client = WebSocketsClient(("127.0.0.1", self.port))
|
||||
client.connect()
|
||||
client.send_message("hello")
|
||||
|
||||
|
||||
class TestFrameHeader:
|
||||
def test_roundtrip(self):
|
||||
def round(*args, **kwargs):
|
||||
f = websockets.FrameHeader(*args, **kwargs)
|
||||
bytes = f.to_bytes()
|
||||
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
|
||||
assert f == f2
|
||||
round()
|
||||
round(fin=1)
|
||||
round(rsv1=1)
|
||||
round(rsv2=1)
|
||||
round(rsv3=1)
|
||||
round(payload_length=1)
|
||||
round(payload_length=100)
|
||||
round(payload_length=1000)
|
||||
round(payload_length=10000)
|
||||
round(opcode=websockets.OPCODE.PING)
|
||||
round(masking_key="test")
|
||||
|
||||
def test_funky(self):
|
||||
f = websockets.FrameHeader(masking_key="test", mask=False)
|
||||
bytes = f.to_bytes()
|
||||
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
|
||||
assert not f2.mask
|
||||
|
||||
|
||||
class TestFrame:
|
||||
def test_roundtrip(self):
|
||||
def round(*args, **kwargs):
|
||||
f = websockets.Frame(*args, **kwargs)
|
||||
bytes = f.to_bytes()
|
||||
f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes))
|
||||
assert f == f2
|
||||
round("test")
|
||||
|
Loading…
Reference in New Issue
Block a user