mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 16:17:49 +00:00
203 lines
6.5 KiB
Python
203 lines
6.5 KiB
Python
from netlib import tcp
|
|
from netlib import test
|
|
from netlib import websockets
|
|
import io
|
|
import os
|
|
from nose.tools import raises
|
|
|
|
|
|
class WebSocketsEchoHandler(tcp.BaseHandler):
|
|
def __init__(self, connection, address, server):
|
|
super(WebSocketsEchoHandler, self).__init__(
|
|
connection, address, server
|
|
)
|
|
self.handshake_done = False
|
|
|
|
def handle(self):
|
|
while True:
|
|
if not self.handshake_done:
|
|
self.handshake()
|
|
else:
|
|
self.read_next_message()
|
|
|
|
def read_next_message(self):
|
|
decoded = websockets.Frame.from_file(self.rfile).decoded_payload
|
|
self.on_message(decoded)
|
|
|
|
def send_message(self, message):
|
|
frame = websockets.Frame.default(message, from_client = False)
|
|
frame.to_file(self.wfile)
|
|
|
|
def handshake(self):
|
|
client_hs = websockets.read_handshake(self.rfile, 1)
|
|
key = websockets.process_handshake_from_client(client_hs)
|
|
response = websockets.create_server_handshake(key)
|
|
self.wfile.write(response)
|
|
self.wfile.flush()
|
|
self.handshake_done = True
|
|
|
|
def on_message(self, message):
|
|
if message is not None:
|
|
self.send_message(message)
|
|
|
|
|
|
class WebSocketsClient(tcp.TCPClient):
|
|
def __init__(self, address, source_address=None):
|
|
super(WebSocketsClient, self).__init__(address, source_address)
|
|
self.version = "13"
|
|
self.client_nonce = websockets.create_client_nonce()
|
|
self.resource = "/"
|
|
|
|
def connect(self):
|
|
super(WebSocketsClient, self).connect()
|
|
|
|
handshake = websockets.create_client_handshake(
|
|
self.address.host,
|
|
self.address.port,
|
|
self.client_nonce,
|
|
self.version,
|
|
self.resource
|
|
)
|
|
|
|
self.wfile.write(handshake)
|
|
self.wfile.flush()
|
|
|
|
server_handshake = websockets.read_handshake(self.rfile, 1)
|
|
server_nonce = websockets.process_handshake_from_server(
|
|
server_handshake
|
|
)
|
|
|
|
if not server_nonce == websockets.create_server_nonce(self.client_nonce):
|
|
self.close()
|
|
|
|
def read_next_message(self):
|
|
return websockets.Frame.from_file(self.rfile).payload
|
|
|
|
def send_message(self, message):
|
|
frame = websockets.Frame.default(message, from_client = True)
|
|
frame.to_file(self.wfile)
|
|
|
|
|
|
class TestWebSockets(test.ServerTestBase):
|
|
handler = WebSocketsEchoHandler
|
|
|
|
def random_bytes(self, n = 100):
|
|
return os.urandom(n)
|
|
|
|
def echo(self, msg):
|
|
client = WebSocketsClient(("127.0.0.1", self.port))
|
|
client.connect()
|
|
client.send_message(msg)
|
|
response = client.read_next_message()
|
|
assert response == msg
|
|
|
|
def test_simple_echo(self):
|
|
self.echo("hello I'm the client")
|
|
|
|
def test_frame_sizes(self):
|
|
# length can fit in the the 7 bit payload length
|
|
small_msg = self.random_bytes(100)
|
|
# 50kb, sligthly larger than can fit in a 7 bit int
|
|
medium_msg = self.random_bytes(50000)
|
|
# 150kb, slightly larger than can fit in a 16 bit int
|
|
large_msg = self.random_bytes(150000)
|
|
|
|
self.echo(small_msg)
|
|
self.echo(medium_msg)
|
|
self.echo(large_msg)
|
|
|
|
def test_default_builder(self):
|
|
"""
|
|
default builder should always generate valid frames
|
|
"""
|
|
msg = self.random_bytes()
|
|
client_frame = websockets.Frame.default(msg, from_client = True)
|
|
assert client_frame.is_valid()
|
|
|
|
server_frame = websockets.Frame.default(msg, from_client = False)
|
|
assert server_frame.is_valid()
|
|
|
|
def test_serialization_bijection(self):
|
|
"""
|
|
Ensure that various frame types can be serialized/deserialized back
|
|
and forth between to_bytes() and from_bytes()
|
|
"""
|
|
for is_client in [True, False]:
|
|
for num_bytes in [100, 50000, 150000]:
|
|
frame = websockets.Frame.default(
|
|
self.random_bytes(num_bytes), is_client
|
|
)
|
|
assert frame == websockets.Frame.from_bytes(
|
|
frame.safe_to_bytes()
|
|
)
|
|
|
|
bytes = b'\x81\x03cba'
|
|
assert websockets.Frame.from_bytes(bytes).safe_to_bytes() == bytes
|
|
|
|
@raises(websockets.WebSocketFrameValidationException)
|
|
def test_safe_to_bytes(self):
|
|
frame = websockets.Frame.default(self.random_bytes(8))
|
|
frame.actual_payload_length = 1 # corrupt the frame
|
|
frame.safe_to_bytes()
|
|
|
|
def test_handshake(self):
|
|
bad_upgrade = "not_websockets"
|
|
bad_header_handshake = websockets.build_handshake([
|
|
('Host', '%s:%s' % ("a", "b")),
|
|
('Connection', "c"),
|
|
('Upgrade', bad_upgrade),
|
|
('Sec-WebSocket-Key', "d"),
|
|
('Sec-WebSocket-Version', "e")
|
|
], "f")
|
|
|
|
# check behavior when required header values are missing
|
|
assert None is websockets.process_handshake_from_server(
|
|
bad_header_handshake
|
|
)
|
|
assert None is websockets.process_handshake_from_client(
|
|
bad_header_handshake
|
|
)
|
|
|
|
key = "test_key"
|
|
|
|
client_handshake = websockets.create_client_handshake(
|
|
"a", "b", key, "d", "e"
|
|
)
|
|
assert key == websockets.process_handshake_from_client(
|
|
client_handshake
|
|
)
|
|
|
|
server_handshake = websockets.create_server_handshake(key)
|
|
assert websockets.create_server_nonce(key) == websockets.process_handshake_from_server(server_handshake)
|
|
|
|
handshake = websockets.create_client_handshake("a", "b", "c", "d", "e")
|
|
stream = io.BytesIO(handshake)
|
|
assert handshake == websockets.read_handshake(stream, 1)
|
|
|
|
# ensure readhandshake doesn't loop forever on empty stream
|
|
empty_stream = io.BytesIO("")
|
|
assert "" == websockets.read_handshake(empty_stream, 1)
|
|
|
|
|
|
class BadHandshakeHandler(WebSocketsEchoHandler):
|
|
def handshake(self):
|
|
client_hs = websockets.read_handshake(self.rfile, 1)
|
|
websockets.process_handshake_from_client(client_hs)
|
|
response = websockets.create_server_handshake("malformed_key")
|
|
self.wfile.write(response)
|
|
self.wfile.flush()
|
|
self.handshake_done = True
|
|
|
|
|
|
class TestBadHandshake(test.ServerTestBase):
|
|
"""
|
|
Ensure that the client disconnects if the server handshake is malformed
|
|
"""
|
|
handler = BadHandshakeHandler
|
|
|
|
@raises(tcp.NetLibDisconnect)
|
|
def test(self):
|
|
client = WebSocketsClient(("127.0.0.1", self.port))
|
|
client.connect()
|
|
client.send_message("hello")
|