mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-30 03:14:22 +00:00
websocket: interface refactoring
- Separate out FrameHeader. We need to deal with this separately in many circumstances. - Simpler equality scheme. - Bits are now specified by truthiness - we don't care about the integer value. This means lots of validation is not needed any more.
This commit is contained in:
parent
3519871f34
commit
f22bc0b4c7
@ -49,3 +49,19 @@ def hexdump(s):
|
|||||||
(o, x, cleanBin(part, True))
|
(o, x, cleanBin(part, True))
|
||||||
)
|
)
|
||||||
return parts
|
return parts
|
||||||
|
|
||||||
|
|
||||||
|
def setbit(byte, offset, value):
|
||||||
|
"""
|
||||||
|
Set a bit in a byte to 1 if value is truthy, 0 if not.
|
||||||
|
"""
|
||||||
|
if value:
|
||||||
|
return byte | (1 << offset)
|
||||||
|
else:
|
||||||
|
return byte & ~(1 << offset)
|
||||||
|
|
||||||
|
|
||||||
|
def getbit(byte, offset):
|
||||||
|
mask = 1 << offset
|
||||||
|
if byte & mask:
|
||||||
|
return True
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
@ -83,23 +82,6 @@ def server_handshake_headers(key):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_payload_length_pair(payload_bytestring):
|
|
||||||
"""
|
|
||||||
A websockets frame contains an initial length_code, and an optional
|
|
||||||
extended length code to represent the actual length if length code is
|
|
||||||
larger than 125
|
|
||||||
"""
|
|
||||||
actual_length = len(payload_bytestring)
|
|
||||||
|
|
||||||
if actual_length <= 125:
|
|
||||||
length_code = actual_length
|
|
||||||
elif actual_length >= 126 and actual_length <= 65535:
|
|
||||||
length_code = 126
|
|
||||||
else:
|
|
||||||
length_code = 127
|
|
||||||
return (length_code, actual_length)
|
|
||||||
|
|
||||||
|
|
||||||
def make_length_code(len):
|
def make_length_code(len):
|
||||||
"""
|
"""
|
||||||
A websockets frame contains an initial length_code, and an optional
|
A websockets frame contains an initial length_code, and an optional
|
||||||
@ -132,41 +114,114 @@ def create_server_nonce(client_nonce):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def frame_header_bytes(
|
DEFAULT = object()
|
||||||
opcode = 0,
|
class FrameHeader:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
opcode = OPCODE.TEXT,
|
||||||
payload_length = 0,
|
payload_length = 0,
|
||||||
fin = 0,
|
fin = False,
|
||||||
rsv1 = 0,
|
rsv1 = False,
|
||||||
rsv2 = 0,
|
rsv2 = False,
|
||||||
rsv3 = 0,
|
rsv3 = False,
|
||||||
mask = 0,
|
|
||||||
masking_key = None,
|
masking_key = None,
|
||||||
length_code = None
|
mask = DEFAULT,
|
||||||
|
length_code = DEFAULT
|
||||||
):
|
):
|
||||||
first_byte = (fin << 7) | (rsv1 << 6) |\
|
self.opcode = opcode
|
||||||
(rsv2 << 4) | (rsv3 << 4) | opcode
|
self.payload_length = payload_length
|
||||||
|
self.fin = fin
|
||||||
|
self.rsv1 = rsv1
|
||||||
|
self.rsv2 = rsv2
|
||||||
|
self.rsv3 = rsv3
|
||||||
|
self.mask = mask
|
||||||
|
self.masking_key = masking_key
|
||||||
|
self.length_code = length_code
|
||||||
|
|
||||||
if length_code is None:
|
def to_bytes(self):
|
||||||
length_code = make_length_code(payload_length)
|
first_byte = utils.setbit(0, 7, self.fin)
|
||||||
|
first_byte = utils.setbit(first_byte, 6, self.rsv1)
|
||||||
|
first_byte = utils.setbit(first_byte, 5, self.rsv2)
|
||||||
|
first_byte = utils.setbit(first_byte, 4, self.rsv3)
|
||||||
|
first_byte = first_byte | self.opcode
|
||||||
|
|
||||||
|
if self.length_code is DEFAULT:
|
||||||
|
length_code = make_length_code(self.payload_length)
|
||||||
|
else:
|
||||||
|
length_code = self.length_code
|
||||||
|
|
||||||
|
if self.mask is DEFAULT:
|
||||||
|
mask = bool(self.masking_key)
|
||||||
|
else:
|
||||||
|
mask = self.mask
|
||||||
|
|
||||||
second_byte = (mask << 7) | length_code
|
second_byte = (mask << 7) | length_code
|
||||||
|
|
||||||
b = chr(first_byte) + chr(second_byte)
|
b = chr(first_byte) + chr(second_byte)
|
||||||
|
|
||||||
if payload_length < 126:
|
if self.payload_length < 126:
|
||||||
pass
|
pass
|
||||||
elif payload_length < MAX_16_BIT_INT:
|
elif self.payload_length < MAX_16_BIT_INT:
|
||||||
# '!H' pack as 16 bit unsigned short
|
# '!H' pack as 16 bit unsigned short
|
||||||
# add 2 byte extended payload length
|
# add 2 byte extended payload length
|
||||||
b += struct.pack('!H', payload_length)
|
b += struct.pack('!H', self.payload_length)
|
||||||
elif payload_length < MAX_64_BIT_INT:
|
elif self.payload_length < MAX_64_BIT_INT:
|
||||||
# '!Q' = pack as 64 bit unsigned long long
|
# '!Q' = pack as 64 bit unsigned long long
|
||||||
# add 8 bytes extended payload length
|
# add 8 bytes extended payload length
|
||||||
b += struct.pack('!Q', payload_length)
|
b += struct.pack('!Q', self.payload_length)
|
||||||
if masking_key is not None:
|
if self.masking_key is not None:
|
||||||
b += masking_key
|
b += self.masking_key
|
||||||
return b
|
return b
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_file(klass, fp):
|
||||||
|
"""
|
||||||
|
read a websockets frame header
|
||||||
|
"""
|
||||||
|
first_byte = utils.bytes_to_int(fp.read(1))
|
||||||
|
second_byte = utils.bytes_to_int(fp.read(1))
|
||||||
|
|
||||||
|
fin = utils.getbit(first_byte, 7)
|
||||||
|
rsv1 = utils.getbit(first_byte, 6)
|
||||||
|
rsv2 = utils.getbit(first_byte, 5)
|
||||||
|
rsv3 = utils.getbit(first_byte, 4)
|
||||||
|
# grab right most 4 bits by and-ing with 00001111
|
||||||
|
opcode = first_byte & 15
|
||||||
|
# grab left most bit
|
||||||
|
mask_bit = second_byte >> 7
|
||||||
|
# grab the next 7 bits
|
||||||
|
length_code = second_byte & 127
|
||||||
|
|
||||||
|
# payload_lengthy > 125 indicates you need to read more bytes
|
||||||
|
# to get the actual payload length
|
||||||
|
if length_code <= 125:
|
||||||
|
payload_length = length_code
|
||||||
|
elif length_code == 126:
|
||||||
|
payload_length = utils.bytes_to_int(fp.read(2))
|
||||||
|
elif length_code == 127:
|
||||||
|
payload_length = utils.bytes_to_int(fp.read(8))
|
||||||
|
|
||||||
|
# masking key only present if mask bit set
|
||||||
|
if mask_bit == 1:
|
||||||
|
masking_key = fp.read(4)
|
||||||
|
else:
|
||||||
|
masking_key = None
|
||||||
|
|
||||||
|
return klass(
|
||||||
|
fin = fin,
|
||||||
|
rsv1 = rsv1,
|
||||||
|
rsv2 = rsv2,
|
||||||
|
rsv3 = rsv3,
|
||||||
|
opcode = opcode,
|
||||||
|
mask = mask_bit,
|
||||||
|
length_code = length_code,
|
||||||
|
payload_length = payload_length,
|
||||||
|
masking_key = masking_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.to_bytes() == other.to_bytes()
|
||||||
|
|
||||||
|
|
||||||
class Frame(object):
|
class Frame(object):
|
||||||
"""
|
"""
|
||||||
@ -194,27 +249,10 @@ class Frame(object):
|
|||||||
| Payload Data continued ... |
|
| Payload Data continued ... |
|
||||||
+---------------------------------------------------------------+
|
+---------------------------------------------------------------+
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(self, payload = "", **kwargs):
|
||||||
self,
|
|
||||||
fin, # decmial integer 1 or 0
|
|
||||||
opcode, # decmial integer 1 - 4
|
|
||||||
payload = "", # bytestring
|
|
||||||
masking_key = None, # 32 bit byte string
|
|
||||||
mask_bit = 0, # decimal integer 1 or 0
|
|
||||||
payload_length_code = None, # decimal integer 1 - 127
|
|
||||||
rsv1 = 0, # decimal integer 1 or 0
|
|
||||||
rsv2 = 0, # decimal integer 1 or 0
|
|
||||||
rsv3 = 0, # decimal integer 1 or 0
|
|
||||||
):
|
|
||||||
self.fin = fin
|
|
||||||
self.rsv1 = rsv1
|
|
||||||
self.rsv2 = rsv2
|
|
||||||
self.rsv3 = rsv3
|
|
||||||
self.opcode = opcode
|
|
||||||
self.mask_bit = mask_bit
|
|
||||||
self.payload_length_code = payload_length_code
|
|
||||||
self.masking_key = masking_key
|
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
|
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
|
||||||
|
self.header = FrameHeader(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default(cls, message, from_client = False):
|
def default(cls, message, from_client = False):
|
||||||
@ -230,10 +268,10 @@ class Frame(object):
|
|||||||
masking_key = None
|
masking_key = None
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
|
message,
|
||||||
fin = 1, # final frame
|
fin = 1, # final frame
|
||||||
opcode = OPCODE.TEXT, # text
|
opcode = OPCODE.TEXT, # text
|
||||||
mask_bit = mask_bit,
|
mask = mask_bit,
|
||||||
payload = message,
|
|
||||||
masking_key = masking_key,
|
masking_key = masking_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -243,30 +281,30 @@ class Frame(object):
|
|||||||
Frame has not been corrupted.
|
Frame has not been corrupted.
|
||||||
"""
|
"""
|
||||||
constraints = [
|
constraints = [
|
||||||
0 <= self.fin <= 1,
|
0 <= self.header.fin <= 1,
|
||||||
0 <= self.rsv1 <= 1,
|
0 <= self.header.rsv1 <= 1,
|
||||||
0 <= self.rsv2 <= 1,
|
0 <= self.header.rsv2 <= 1,
|
||||||
0 <= self.rsv3 <= 1,
|
0 <= self.header.rsv3 <= 1,
|
||||||
1 <= self.opcode <= 4,
|
1 <= self.header.opcode <= 4,
|
||||||
0 <= self.mask_bit <= 1,
|
0 <= self.header.mask <= 1,
|
||||||
#1 <= self.payload_length_code <= 127,
|
#1 <= self.payload_length_code <= 127,
|
||||||
1 <= len(self.masking_key) <= 4 if self.mask_bit else True,
|
1 <= len(self.header.masking_key) <= 4 if self.header.mask else True,
|
||||||
self.masking_key is not None if self.mask_bit else True
|
self.header.masking_key is not None if self.header.mask else True
|
||||||
]
|
]
|
||||||
if not all(constraints):
|
if not all(constraints):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def human_readable(self): # pragma: nocover
|
def human_readable(self):
|
||||||
return "\n".join([
|
return "\n".join([
|
||||||
("fin - " + str(self.fin)),
|
("fin - " + str(self.header.fin)),
|
||||||
("rsv1 - " + str(self.rsv1)),
|
("rsv1 - " + str(self.header.rsv1)),
|
||||||
("rsv2 - " + str(self.rsv2)),
|
("rsv2 - " + str(self.header.rsv2)),
|
||||||
("rsv3 - " + str(self.rsv3)),
|
("rsv3 - " + str(self.header.rsv3)),
|
||||||
("opcode - " + str(self.opcode)),
|
("opcode - " + str(self.header.opcode)),
|
||||||
("mask_bit - " + str(self.mask_bit)),
|
("mask - " + str(self.header.mask)),
|
||||||
("payload_length_code - " + str(self.payload_length_code)),
|
("length_code - " + str(self.header.length_code)),
|
||||||
("masking_key - " + repr(str(self.masking_key))),
|
("masking_key - " + repr(str(self.header.masking_key))),
|
||||||
("payload - " + repr(str(self.payload))),
|
("payload - " + repr(str(self.payload))),
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -284,18 +322,9 @@ class Frame(object):
|
|||||||
If you haven't checked is_valid_frame() then there's no guarentees
|
If you haven't checked is_valid_frame() then there's no guarentees
|
||||||
that the serialized bytes will be correct. see safe_to_bytes()
|
that the serialized bytes will be correct. see safe_to_bytes()
|
||||||
"""
|
"""
|
||||||
b = frame_header_bytes(
|
b = self.header.to_bytes()
|
||||||
opcode = self.opcode,
|
if self.header.masking_key:
|
||||||
fin = self.fin,
|
b += apply_mask(self.payload, self.header.masking_key)
|
||||||
rsv1 = self.rsv1,
|
|
||||||
rsv2 = self.rsv2,
|
|
||||||
rsv3 = self.rsv3,
|
|
||||||
mask = self.mask_bit,
|
|
||||||
masking_key = self.masking_key,
|
|
||||||
payload_length = len(self.payload) if self.payload else 0
|
|
||||||
)
|
|
||||||
if self.masking_key:
|
|
||||||
b += apply_mask(self.payload, self.masking_key)
|
|
||||||
else:
|
else:
|
||||||
b += self.payload
|
b += self.payload
|
||||||
return b
|
return b
|
||||||
@ -312,66 +341,20 @@ class Frame(object):
|
|||||||
fp is a "file like" object that could be backed by a network
|
fp is a "file like" object that could be backed by a network
|
||||||
stream or a disk or an in memory stream reader
|
stream or a disk or an in memory stream reader
|
||||||
"""
|
"""
|
||||||
first_byte = utils.bytes_to_int(fp.read(1))
|
header = FrameHeader.from_file(fp)
|
||||||
second_byte = utils.bytes_to_int(fp.read(1))
|
payload = fp.read(header.payload_length)
|
||||||
|
|
||||||
# grab the left most bit
|
if header.mask == 1 and header.masking_key:
|
||||||
fin = first_byte >> 7
|
payload = apply_mask(payload, header.masking_key)
|
||||||
# grab right most 4 bits by and-ing with 00001111
|
|
||||||
opcode = first_byte & 15
|
|
||||||
# grab left most bit
|
|
||||||
mask_bit = second_byte >> 7
|
|
||||||
# grab the next 7 bits
|
|
||||||
payload_length = second_byte & 127
|
|
||||||
|
|
||||||
# payload_lengthy > 125 indicates you need to read more bytes
|
|
||||||
# to get the actual payload length
|
|
||||||
if payload_length <= 125:
|
|
||||||
actual_payload_length = payload_length
|
|
||||||
|
|
||||||
elif payload_length == 126:
|
|
||||||
actual_payload_length = utils.bytes_to_int(fp.read(2))
|
|
||||||
|
|
||||||
elif payload_length == 127:
|
|
||||||
actual_payload_length = utils.bytes_to_int(fp.read(8))
|
|
||||||
|
|
||||||
# masking key only present if mask bit set
|
|
||||||
if mask_bit == 1:
|
|
||||||
masking_key = fp.read(4)
|
|
||||||
else:
|
|
||||||
masking_key = None
|
|
||||||
|
|
||||||
payload = fp.read(actual_payload_length)
|
|
||||||
|
|
||||||
if mask_bit == 1 and masking_key:
|
|
||||||
payload = apply_mask(payload, masking_key)
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
fin = fin,
|
payload,
|
||||||
opcode = opcode,
|
fin = header.fin,
|
||||||
mask_bit = mask_bit,
|
opcode = header.opcode,
|
||||||
payload_length_code = payload_length,
|
mask = header.mask,
|
||||||
payload = payload,
|
payload_length = header.payload_length,
|
||||||
masking_key = masking_key,
|
masking_key = header.masking_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if self.payload_length_code is None:
|
return self.to_bytes() == other.to_bytes()
|
||||||
myplc = make_length_code(len(self.payload))
|
|
||||||
else:
|
|
||||||
myplc = self.payload_length_code
|
|
||||||
if other.payload_length_code is None:
|
|
||||||
otherplc = make_length_code(len(other.payload))
|
|
||||||
else:
|
|
||||||
otherplc = other.payload_length_code
|
|
||||||
return (
|
|
||||||
self.fin == other.fin and
|
|
||||||
self.rsv1 == other.rsv1 and
|
|
||||||
self.rsv2 == other.rsv2 and
|
|
||||||
self.rsv3 == other.rsv3 and
|
|
||||||
self.opcode == other.opcode and
|
|
||||||
self.mask_bit == other.mask_bit and
|
|
||||||
self.masking_key == other.masking_key and
|
|
||||||
self.payload == other.payload,
|
|
||||||
myplc == otherplc
|
|
||||||
)
|
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
from netlib import tcp, test, websockets, http
|
import cStringIO
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from nose.tools import raises
|
from nose.tools import raises
|
||||||
|
|
||||||
|
from netlib import tcp, test, websockets, http
|
||||||
def test_frame_header_bytes():
|
|
||||||
assert websockets.frame_header_bytes()
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketsEchoHandler(tcp.BaseHandler):
|
class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||||
@ -119,12 +118,12 @@ class TestWebSockets(test.ServerTestBase):
|
|||||||
assert frame.is_valid()
|
assert frame.is_valid()
|
||||||
|
|
||||||
frame = f()
|
frame = f()
|
||||||
frame.fin = 2
|
frame.header.fin = 2
|
||||||
assert not frame.is_valid()
|
assert not frame.is_valid()
|
||||||
|
|
||||||
frame = f()
|
frame = f()
|
||||||
frame.mask_bit = 1
|
frame.header.mask_bit = 1
|
||||||
frame.masking_key = "foobbarboo"
|
frame.header.masking_key = "foobbarboo"
|
||||||
assert not frame.is_valid()
|
assert not frame.is_valid()
|
||||||
|
|
||||||
def test_serialization_bijection(self):
|
def test_serialization_bijection(self):
|
||||||
@ -181,3 +180,39 @@ class TestBadHandshake(test.ServerTestBase):
|
|||||||
client = WebSocketsClient(("127.0.0.1", self.port))
|
client = WebSocketsClient(("127.0.0.1", self.port))
|
||||||
client.connect()
|
client.connect()
|
||||||
client.send_message("hello")
|
client.send_message("hello")
|
||||||
|
|
||||||
|
|
||||||
|
class TestFrameHeader:
|
||||||
|
def test_roundtrip(self):
|
||||||
|
def round(*args, **kwargs):
|
||||||
|
f = websockets.FrameHeader(*args, **kwargs)
|
||||||
|
bytes = f.to_bytes()
|
||||||
|
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
|
||||||
|
assert f == f2
|
||||||
|
round()
|
||||||
|
round(fin=1)
|
||||||
|
round(rsv1=1)
|
||||||
|
round(rsv2=1)
|
||||||
|
round(rsv3=1)
|
||||||
|
round(payload_length=1)
|
||||||
|
round(payload_length=100)
|
||||||
|
round(payload_length=1000)
|
||||||
|
round(payload_length=10000)
|
||||||
|
round(opcode=websockets.OPCODE.PING)
|
||||||
|
round(masking_key="test")
|
||||||
|
|
||||||
|
def test_funky(self):
|
||||||
|
f = websockets.FrameHeader(masking_key="test", mask=False)
|
||||||
|
bytes = f.to_bytes()
|
||||||
|
f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes))
|
||||||
|
assert not f2.mask
|
||||||
|
|
||||||
|
|
||||||
|
class TestFrame:
|
||||||
|
def test_roundtrip(self):
|
||||||
|
def round(*args, **kwargs):
|
||||||
|
f = websockets.Frame(*args, **kwargs)
|
||||||
|
bytes = f.to_bytes()
|
||||||
|
f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes))
|
||||||
|
assert f == f2
|
||||||
|
round("test")
|
||||||
|
Loading…
Reference in New Issue
Block a user