# mitmproxy/test/websockets/test_websockets.py
#
# 268 lines, 8.2 KiB, Python (Raw / Normal View / History)
#
# 2015-04-11 18:35:15 +00:00
import os
# 2015-04-10 02:35:40 +00:00
# 2015-09-15 22:04:23 +00:00
from netlib.http.http1 import read_response, read_request
from netlib import tcp, tutils, websockets, http
# 2015-08-05 19:32:53 +00:00
from netlib.http import status_codes
# 2015-09-15 22:04:23 +00:00
from netlib.tutils import treq
from netlib.exceptions import *
# 2015-08-01 12:49:15 +00:00
from .. import tservers
class WebSocketsEchoHandler(tcp.BaseHandler):
    """
    Test server handler.

    On the first iteration it performs the server side of the WebSocket
    opening handshake. Afterwards it reads frames and echoes each non-None
    payload back to the client in a new server-side frame.
    """

    def __init__(self, connection, address, server):
        super(WebSocketsEchoHandler, self).__init__(
            connection, address, server
        )
        self.protocol = websockets.WebsocketsProtocol()
        self.handshake_done = False

    def handle(self):
        # There is no break: the loop only ends when an exception propagates
        # (presumably from the socket once the client goes away).
        while True:
            if not self.handshake_done:
                self.handshake()
            else:
                self.read_next_message()

    def read_next_message(self):
        """Read one frame from the client and pass its payload to on_message()."""
        frame = websockets.Frame.from_file(self.rfile)
        self.on_message(frame.payload)

    def send_message(self, message):
        """Wrap ``message`` in a default server-side frame and write it out."""
        frame = websockets.Frame.default(message, from_client=False)
        frame.to_file(self.wfile)

    def handshake(self):
        """
        Read the client's upgrade request and validate it with the protocol
        helper. Reply with a ``101`` status line followed by the server
        handshake headers derived from the client's key.
        """
        req = read_request(self.rfile)
        key = self.protocol.check_client_handshake(req.headers)

        preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
        self.wfile.write(preamble.encode() + b"\r\n")
        headers = self.protocol.server_handshake_headers(key)
        # Write bytes, as the preamble above does and as WebSocketsClient and
        # BadHandshakeHandler do with bytes(headers). The previous
        # str(headers) + "\r\n" handed a text string to the byte stream,
        # which breaks on Python 3.
        self.wfile.write(bytes(headers) + b"\r\n")
        self.wfile.flush()
        self.handshake_done = True

    def on_message(self, message):
        """Echo ``message`` back to the client unless it is None."""
        if message is not None:
            self.send_message(message)
class WebSocketsClient(tcp.TCPClient):
    """
    Minimal WebSocket client for the tests.

    It opens a TCP connection, runs the client side of the opening
    handshake, and then sends and receives single frames.
    """

    def __init__(self, address, source_address=None):
        super(WebSocketsClient, self).__init__(address, source_address)
        self.protocol = websockets.WebsocketsProtocol()
        self.client_nonce = None

    def connect(self):
        """
        Connect and perform the opening handshake.

        The nonce returned by the server check is compared with the one
        derived from our own ``sec-websocket-key``. On a mismatch the
        connection is closed here rather than an error being raised.
        """
        super(WebSocketsClient, self).connect()

        request_line = b'GET / HTTP/1.1'
        self.wfile.write(request_line + b"\r\n")

        handshake_headers = self.protocol.client_handshake_headers()
        self.client_nonce = handshake_headers["sec-websocket-key"].encode("ascii")
        self.wfile.write(bytes(handshake_headers) + b"\r\n")
        self.wfile.flush()

        response = read_response(self.rfile, treq(method=b"GET"))
        server_nonce = self.protocol.check_server_handshake(response.headers)
        expected_nonce = self.protocol.create_server_nonce(self.client_nonce)
        if server_nonce != expected_nonce:
            self.close()

    def read_next_message(self):
        """Read one frame from the server and return its payload."""
        incoming = websockets.Frame.from_file(self.rfile)
        return incoming.payload

    def send_message(self, message):
        """Send ``message`` wrapped in a default client-side frame."""
        websockets.Frame.default(message, from_client=True).to_file(self.wfile)
class TestWebSockets(tservers.ServerTestBase):
    """
    Round-trip tests against a live WebSocketsEchoHandler server, plus unit
    tests for frame serialization and the handshake header checks.
    """
    handler = WebSocketsEchoHandler

    # Class attribute rather than an __init__ override. pytest does not
    # collect test classes that define __init__, and the old __init__ also
    # skipped the base-class constructor.
    # NOTE(review): one instance is shared by all tests; this assumes
    # WebsocketsProtocol keeps no per-test state - confirm.
    protocol = websockets.WebsocketsProtocol()

    def random_bytes(self, n=100):
        """Return ``n`` bytes from os.urandom()."""
        return os.urandom(n)

    def echo(self, msg):
        """Send ``msg`` to the echo server and assert it comes back unchanged."""
        client = WebSocketsClient(("127.0.0.1", self.port))
        client.connect()
        try:
            client.send_message(msg)
            response = client.read_next_message()
        finally:
            # Close the socket explicitly instead of leaking it until
            # garbage collection.
            client.close()
        assert response == msg

    def test_simple_echo(self):
        self.echo(b"hello I'm the client")

    def test_frame_sizes(self):
        # 100 bytes: fits in the 7-bit payload length field
        small_msg = self.random_bytes(100)
        # 50kb: too large for the 7-bit field, fits a 16-bit extended length
        medium_msg = self.random_bytes(50000)
        # 150kb: larger than a 16-bit extended length can express
        large_msg = self.random_bytes(150000)

        self.echo(small_msg)
        self.echo(medium_msg)
        self.echo(large_msg)

    def test_default_builder(self):
        """
        default builder should always generate valid frames
        """
        msg = self.random_bytes()
        client_frame = websockets.Frame.default(msg, from_client=True)
        server_frame = websockets.Frame.default(msg, from_client=False)

        for frame in (client_frame, server_frame):
            assert frame.payload == msg
            # A valid frame survives a serialize/parse round trip unchanged.
            assert websockets.Frame.from_bytes(frame.to_bytes()) == frame

        # RFC 6455: client-to-server frames are masked,
        # server-to-client frames are not.
        assert client_frame.header.mask
        assert not server_frame.header.mask

    def test_serialization_bijection(self):
        """
        Ensure that various frame types can be serialized/deserialized back
        and forth between to_bytes() and from_bytes()
        """
        for is_client in [True, False]:
            for num_bytes in [100, 50000, 150000]:
                frame = websockets.Frame.default(
                    self.random_bytes(num_bytes), is_client
                )
                frame2 = websockets.Frame.from_bytes(
                    frame.to_bytes()
                )
                assert frame == frame2

        # 0x81 = FIN + text opcode, 0x03 = unmasked with a 3-byte payload.
        # Named ``raw`` so the ``bytes`` builtin is not shadowed.
        raw = b'\x81\x03cba'
        assert websockets.Frame.from_bytes(raw).to_bytes() == raw

    def test_check_server_handshake(self):
        headers = self.protocol.server_handshake_headers("key")
        assert self.protocol.check_server_handshake(headers)

        headers["Upgrade"] = "not_websocket"
        assert not self.protocol.check_server_handshake(headers)

    def test_check_client_handshake(self):
        headers = self.protocol.client_handshake_headers("key")
        assert self.protocol.check_client_handshake(headers) == "key"

        headers["Upgrade"] = "not_websocket"
        assert not self.protocol.check_client_handshake(headers)
class BadHandshakeHandler(WebSocketsEchoHandler):
    """
    Echo handler whose handshake reply is built from a fixed bogus key
    instead of the key the client sent. The client's nonce comparison
    should therefore fail.
    """

    def handshake(self):
        request = read_request(self.rfile)
        # Validate the request as usual, but ignore the key it yields.
        self.protocol.check_client_handshake(request.headers)

        status_line = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
        self.wfile.write(status_line.encode())

        bogus_headers = self.protocol.server_handshake_headers(b"malformed key")
        self.wfile.write(bytes(bogus_headers) + b"\r\n")
        self.wfile.flush()

        self.handshake_done = True
class TestBadHandshake(tservers.ServerTestBase):
    """
    The client must drop the connection when the server's handshake reply
    is malformed. A later send is expected to raise TcpDisconnect.
    """
    handler = BadHandshakeHandler

    def test(self):
        with tutils.raises(TcpDisconnect):
            ws_client = WebSocketsClient(("127.0.0.1", self.port))
            ws_client.connect()
            ws_client.send_message(b"hello")
class TestFrameHeader:
    """Tests for websockets.FrameHeader serialization, repr and validation."""

    def test_roundtrip(self):
        # Named to avoid shadowing the ``round`` builtin.
        def assert_roundtrip(*args, **kwargs):
            """Serialize a FrameHeader and check that parsing it back gives an equal header."""
            header = websockets.FrameHeader(*args, **kwargs)
            parsed = websockets.FrameHeader.from_file(tutils.treader(bytes(header)))
            assert header == parsed

        assert_roundtrip()
        assert_roundtrip(fin=1)
        assert_roundtrip(rsv1=1)
        assert_roundtrip(rsv2=1)
        assert_roundtrip(rsv3=1)
        # A spread of payload lengths, from tiny to several kilobytes.
        assert_roundtrip(payload_length=1)
        assert_roundtrip(payload_length=100)
        assert_roundtrip(payload_length=1000)
        assert_roundtrip(payload_length=10000)
        assert_roundtrip(opcode=websockets.OPCODE.PING)
        assert_roundtrip(masking_key=b"test")

    def test_human_readable(self):
        f = websockets.FrameHeader(
            masking_key=b"test",
            fin=True,
            payload_length=10
        )
        assert repr(f)
        f = websockets.FrameHeader()
        assert repr(f)

    def test_funky(self):
        # A masking key given together with mask=False: the parsed header
        # must come back unmasked.
        f = websockets.FrameHeader(masking_key=b"test", mask=False)
        raw = bytes(f)
        f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
        assert not f2.mask

    def test_violations(self):
        # 17 does not fit in the 4-bit opcode field.
        tutils.raises("opcode", websockets.FrameHeader, opcode=17)
        # A 1-byte masking key is rejected (the valid keys here are 4 bytes).
        tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")

    def test_automask(self):
        # mask=True with no key: a key is generated.
        f = websockets.FrameHeader(mask=True)
        assert f.masking_key

        # A key with no explicit mask flag: the mask flag is set.
        f = websockets.FrameHeader(masking_key=b"foob")
        assert f.mask

        # An explicit mask=0 wins over the key, but the key is kept.
        f = websockets.FrameHeader(masking_key=b"foob", mask=0)
        assert not f.mask
        assert f.masking_key
class TestFrame:
    """Tests for websockets.Frame serialization and repr."""

    def test_roundtrip(self):
        # Named to avoid shadowing the ``round`` builtin.
        def assert_roundtrip(*args, **kwargs):
            """Serialize a Frame and check that parsing it back gives an equal frame."""
            f = websockets.Frame(*args, **kwargs)
            raw = bytes(f)
            f2 = websockets.Frame.from_file(tutils.treader(raw))
            assert f == f2

        assert_roundtrip(b"test")
        assert_roundtrip(b"test", fin=1)
        assert_roundtrip(b"test", rsv1=1)
        assert_roundtrip(b"test", opcode=websockets.OPCODE.PING)
        assert_roundtrip(b"test", masking_key=b"test")

    def test_human_readable(self):
        f = websockets.Frame()
        assert repr(f)
def test_masker():
    """
    Mask a sequence of chunks with a single Masker, then unmask the
    concatenated output with a fresh Masker using the same key. The
    original bytes must come back, including for chunk lengths that are
    not multiples of the 4-byte key.
    """
    key = b"abcd"
    chunk_sets = (
        [b"a"],
        [b"four"],
        [b"fourf"],
        [b"fourfive"],
        [b"a", b"aasdfasdfa", b"asdf"],
        [b"a" * 50, b"aasdfasdfa", b"asdf"],
    )
    for chunks in chunk_sets:
        masker = websockets.Masker(key)
        masked = b"".join(masker(chunk) for chunk in chunks)
        unmasked = websockets.Masker(key)(masked)
        assert unmasked == b"".join(chunks)