mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
websockets: refactor to avoid rundantly specifying payloads and payload lengths
This commit is contained in:
parent
bdd52fead3
commit
3519871f34
@ -198,15 +198,13 @@ class Frame(object):
|
||||
self,
|
||||
fin, # decmial integer 1 or 0
|
||||
opcode, # decmial integer 1 - 4
|
||||
mask_bit, # decimal integer 1 or 0
|
||||
payload_length_code, # decimal integer 1 - 127
|
||||
decoded_payload, # bytestring
|
||||
payload = "", # bytestring
|
||||
masking_key = None, # 32 bit byte string
|
||||
mask_bit = 0, # decimal integer 1 or 0
|
||||
payload_length_code = None, # decimal integer 1 - 127
|
||||
rsv1 = 0, # decimal integer 1 or 0
|
||||
rsv2 = 0, # decimal integer 1 or 0
|
||||
rsv3 = 0, # decimal integer 1 or 0
|
||||
payload = None, # bytestring
|
||||
masking_key = None, # 32 bit byte string
|
||||
actual_payload_length = None, # any decimal integer
|
||||
):
|
||||
self.fin = fin
|
||||
self.rsv1 = rsv1
|
||||
@ -217,8 +215,6 @@ class Frame(object):
|
||||
self.payload_length_code = payload_length_code
|
||||
self.masking_key = masking_key
|
||||
self.payload = payload
|
||||
self.decoded_payload = decoded_payload
|
||||
self.actual_payload_length = actual_payload_length
|
||||
|
||||
@classmethod
|
||||
def default(cls, message, from_client = False):
|
||||
@ -226,27 +222,19 @@ class Frame(object):
|
||||
Construct a basic websocket frame from some default values.
|
||||
Creates a non-fragmented text frame.
|
||||
"""
|
||||
length_code, actual_length = get_payload_length_pair(message)
|
||||
|
||||
if from_client:
|
||||
mask_bit = 1
|
||||
# Random masking key
|
||||
masking_key = os.urandom(4)
|
||||
payload = apply_mask(message, masking_key)
|
||||
else:
|
||||
mask_bit = 0
|
||||
masking_key = None
|
||||
payload = message
|
||||
|
||||
return cls(
|
||||
fin = 1, # final frame
|
||||
opcode = OPCODE.TEXT, # text
|
||||
mask_bit = mask_bit,
|
||||
payload_length_code = length_code,
|
||||
payload = payload,
|
||||
payload = message,
|
||||
masking_key = masking_key,
|
||||
decoded_payload = message,
|
||||
actual_payload_length = actual_length
|
||||
)
|
||||
|
||||
def is_valid(self):
|
||||
@ -261,17 +249,12 @@ class Frame(object):
|
||||
0 <= self.rsv3 <= 1,
|
||||
1 <= self.opcode <= 4,
|
||||
0 <= self.mask_bit <= 1,
|
||||
1 <= self.payload_length_code <= 127,
|
||||
self.actual_payload_length == len(self.payload),
|
||||
#1 <= self.payload_length_code <= 127,
|
||||
1 <= len(self.masking_key) <= 4 if self.mask_bit else True,
|
||||
self.masking_key is not None if self.mask_bit else True
|
||||
]
|
||||
if not all(constraints):
|
||||
return False
|
||||
elif self.payload and self.masking_key:
|
||||
decoded = apply_mask(self.payload, self.masking_key)
|
||||
if decoded != self.decoded_payload:
|
||||
return False
|
||||
return True
|
||||
|
||||
def human_readable(self): # pragma: nocover
|
||||
@ -285,8 +268,6 @@ class Frame(object):
|
||||
("payload_length_code - " + str(self.payload_length_code)),
|
||||
("masking_key - " + repr(str(self.masking_key))),
|
||||
("payload - " + repr(str(self.payload))),
|
||||
("decoded_payload - " + repr(str(self.decoded_payload))),
|
||||
("actual_payload_length - " + str(self.actual_payload_length))
|
||||
])
|
||||
|
||||
@classmethod
|
||||
@ -311,9 +292,12 @@ class Frame(object):
|
||||
rsv3 = self.rsv3,
|
||||
mask = self.mask_bit,
|
||||
masking_key = self.masking_key,
|
||||
payload_length = self.actual_payload_length
|
||||
payload_length = len(self.payload) if self.payload else 0
|
||||
)
|
||||
b += self.payload # already will be encoded if neccessary
|
||||
if self.masking_key:
|
||||
b += apply_mask(self.payload, self.masking_key)
|
||||
else:
|
||||
b += self.payload
|
||||
return b
|
||||
|
||||
def to_file(self, writer):
|
||||
@ -359,10 +343,8 @@ class Frame(object):
|
||||
|
||||
payload = fp.read(actual_payload_length)
|
||||
|
||||
if mask_bit == 1:
|
||||
decoded_payload = apply_mask(payload, masking_key)
|
||||
else:
|
||||
decoded_payload = payload
|
||||
if mask_bit == 1 and masking_key:
|
||||
payload = apply_mask(payload, masking_key)
|
||||
|
||||
return cls(
|
||||
fin = fin,
|
||||
@ -371,11 +353,17 @@ class Frame(object):
|
||||
payload_length_code = payload_length,
|
||||
payload = payload,
|
||||
masking_key = masking_key,
|
||||
decoded_payload = decoded_payload,
|
||||
actual_payload_length = actual_payload_length
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if self.payload_length_code is None:
|
||||
myplc = make_length_code(len(self.payload))
|
||||
else:
|
||||
myplc = self.payload_length_code
|
||||
if other.payload_length_code is None:
|
||||
otherplc = make_length_code(len(other.payload))
|
||||
else:
|
||||
otherplc = other.payload_length_code
|
||||
return (
|
||||
self.fin == other.fin and
|
||||
self.rsv1 == other.rsv1 and
|
||||
@ -383,9 +371,7 @@ class Frame(object):
|
||||
self.rsv3 == other.rsv3 and
|
||||
self.opcode == other.opcode and
|
||||
self.mask_bit == other.mask_bit and
|
||||
self.payload_length_code == other.payload_length_code and
|
||||
self.masking_key == other.masking_key and
|
||||
self.payload == other.payload and
|
||||
self.decoded_payload == other.decoded_payload and
|
||||
self.actual_payload_length == other.actual_payload_length
|
||||
self.payload == other.payload,
|
||||
myplc == otherplc
|
||||
)
|
||||
|
@ -23,7 +23,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
|
||||
def read_next_message(self):
|
||||
frame = websockets.Frame.from_file(self.rfile)
|
||||
self.on_message(frame.decoded_payload)
|
||||
self.on_message(frame.payload)
|
||||
|
||||
def send_message(self, message):
|
||||
frame = websockets.Frame.default(message, from_client = False)
|
||||
@ -107,7 +107,6 @@ class TestWebSockets(test.ServerTestBase):
|
||||
"""
|
||||
msg = self.random_bytes()
|
||||
client_frame = websockets.Frame.default(msg, from_client = True)
|
||||
assert client_frame.is_valid()
|
||||
|
||||
server_frame = websockets.Frame.default(msg, from_client = False)
|
||||
assert server_frame.is_valid()
|
||||
@ -128,17 +127,6 @@ class TestWebSockets(test.ServerTestBase):
|
||||
frame.masking_key = "foobbarboo"
|
||||
assert not frame.is_valid()
|
||||
|
||||
frame = f()
|
||||
frame.mask_bit = 0
|
||||
frame.masking_key = "foob"
|
||||
assert not frame.is_valid()
|
||||
|
||||
frame = f()
|
||||
frame.masking_key = "foob"
|
||||
frame.decoded_payload = "xxxx"
|
||||
assert not frame.is_valid()
|
||||
|
||||
|
||||
def test_serialization_bijection(self):
|
||||
"""
|
||||
Ensure that various frame types can be serialized/deserialized back
|
||||
@ -149,9 +137,10 @@ class TestWebSockets(test.ServerTestBase):
|
||||
frame = websockets.Frame.default(
|
||||
self.random_bytes(num_bytes), is_client
|
||||
)
|
||||
assert frame == websockets.Frame.from_bytes(
|
||||
frame2 = websockets.Frame.from_bytes(
|
||||
frame.to_bytes()
|
||||
)
|
||||
assert frame == frame2
|
||||
|
||||
bytes = b'\x81\x03cba'
|
||||
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
|
||||
|
Loading…
Reference in New Issue
Block a user