Merge pull request #1797 from Kriechi/websocket++

Pass Autobahn WebSocket Test Suite
This commit is contained in:
Maximilian Hils 2016-12-01 09:28:18 +01:00 committed by GitHub
commit d658783dec
3 changed files with 67 additions and 22 deletions

View File

@ -37,6 +37,7 @@ import sys
import functools import functools
from mitmproxy import http from mitmproxy import http
from mitmproxy import websocket
from mitmproxy import tcp from mitmproxy import tcp
from mitmproxy import flow from mitmproxy import flow
@ -99,6 +100,14 @@ class FHTTP(_Action):
return True return True
class FWebSocket(_Action):
code = "websocket"
help = "Match WebSocket flows"
@only(websocket.WebSocketFlow)
def __call__(self, f):
return True
class FTCP(_Action): class FTCP(_Action):
code = "tcp" code = "tcp"
help = "Match TCP flows" help = "Match TCP flows"
@ -245,7 +254,7 @@ class FBod(_Rex):
help = "Body" help = "Body"
flags = re.DOTALL flags = re.DOTALL
@only(http.HTTPFlow, tcp.TCPFlow) @only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow)
def __call__(self, f): def __call__(self, f):
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if f.request and f.request.raw_content: if f.request and f.request.raw_content:
@ -254,7 +263,7 @@ class FBod(_Rex):
if f.response and f.response.raw_content: if f.response and f.response.raw_content:
if self.re.search(f.response.get_content(strict=False)): if self.re.search(f.response.get_content(strict=False)):
return True return True
elif isinstance(f, tcp.TCPFlow): elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow):
for msg in f.messages: for msg in f.messages:
if self.re.search(msg.content): if self.re.search(msg.content):
return True return True
@ -266,13 +275,13 @@ class FBodRequest(_Rex):
help = "Request body" help = "Request body"
flags = re.DOTALL flags = re.DOTALL
@only(http.HTTPFlow, tcp.TCPFlow) @only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow)
def __call__(self, f): def __call__(self, f):
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if f.request and f.request.raw_content: if f.request and f.request.raw_content:
if self.re.search(f.request.get_content(strict=False)): if self.re.search(f.request.get_content(strict=False)):
return True return True
elif isinstance(f, tcp.TCPFlow): elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow):
for msg in f.messages: for msg in f.messages:
if msg.from_client and self.re.search(msg.content): if msg.from_client and self.re.search(msg.content):
return True return True
@ -283,13 +292,13 @@ class FBodResponse(_Rex):
help = "Response body" help = "Response body"
flags = re.DOTALL flags = re.DOTALL
@only(http.HTTPFlow, tcp.TCPFlow) @only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow)
def __call__(self, f): def __call__(self, f):
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if f.response and f.response.raw_content: if f.response and f.response.raw_content:
if self.re.search(f.response.get_content(strict=False)): if self.re.search(f.response.get_content(strict=False)):
return True return True
elif isinstance(f, tcp.TCPFlow): elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow):
for msg in f.messages: for msg in f.messages:
if not msg.from_client and self.re.search(msg.content): if not msg.from_client and self.re.search(msg.content):
return True return True

View File

@ -59,30 +59,56 @@ class WebSocketLayer(base.Layer):
fb.append(frame) fb.append(frame)
if frame.header.fin: if frame.header.fin:
if frame.header.opcode == websockets.OPCODE.TEXT: payload = b''.join(f.payload for f in fb)
original_chunk_sizes = [len(f.payload) for f in fb]
message_type = fb[0].header.opcode
compressed_message = fb[0].header.rsv1
fb.clear()
if message_type == websockets.OPCODE.TEXT:
t = WebSocketTextMessage t = WebSocketTextMessage
else: else:
t = WebSocketBinaryMessage t = WebSocketBinaryMessage
payload = b''.join(f.payload for f in fb)
fb.clear()
websocket_message = t(self.flow, not is_server, payload) websocket_message = t(self.flow, not is_server, payload)
length = len(websocket_message.content)
self.flow.messages.append(websocket_message) self.flow.messages.append(websocket_message)
self.channel.ask("websocket_message", self.flow) self.channel.ask("websocket_message", self.flow)
# chunk payload into multiple 10kB frames, and send them def get_chunk(payload):
payload = websocket_message.content if len(payload) == length:
chunk_size = 10240 # 10kB # message has the same length, we can reuse the same sizes
pos = 0
for s in original_chunk_sizes:
yield payload[pos:pos + s]
pos += s
else:
# just re-chunk everything into 10kB frames
chunk_size = 10240
chunks = range(0, len(payload), chunk_size) chunks = range(0, len(payload), chunk_size)
for i in chunks:
yield payload[i:i + chunk_size]
frms = [ frms = [
websockets.Frame( websockets.Frame(
payload=payload[i:i + chunk_size], payload=chunk,
opcode=frame.header.opcode, opcode=frame.header.opcode,
mask=(False if is_server else 1), mask=(False if is_server else 1),
masking_key=(b'' if is_server else os.urandom(4))) for i in chunks masking_key=(b'' if is_server else os.urandom(4)))
for chunk in get_chunk(websocket_message.content)
] ]
frms[-1].header.fin = 1
if len(frms) > 0:
frms[-1].header.fin = True
else:
frms.append(websockets.Frame(
fin=True,
opcode=websockets.OPCODE.CONTINUE,
mask=(False if is_server else 1),
masking_key=(b'' if is_server else os.urandom(4))))
frms[0].header.opcode = message_type
frms[0].header.rsv1 = compressed_message
for frm in frms: for frm in frms:
other_conn.send(bytes(frm)) other_conn.send(bytes(frm))
@ -105,7 +131,7 @@ class WebSocketLayer(base.Layer):
other_conn.send(bytes(frame)) other_conn.send(bytes(frame))
# close the connection # initiate close handshake
return False return False
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server): def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
@ -126,10 +152,11 @@ class WebSocketLayer(base.Layer):
client = self.client_conn.connection client = self.client_conn.connection
server = self.server_conn.connection server = self.server_conn.connection
conns = [client, server] conns = [client, server]
close_received = False
try: try:
while not self.channel.should_exit.is_set(): while not self.channel.should_exit.is_set():
r = tcp.ssl_read_select(conns, 1) r = tcp.ssl_read_select(conns, 0.1)
for conn in r: for conn in r:
source_conn = self.client_conn if conn == client else self.server_conn source_conn = self.client_conn if conn == client else self.server_conn
other_conn = self.server_conn if conn == client else self.client_conn other_conn = self.server_conn if conn == client else self.client_conn
@ -137,10 +164,15 @@ class WebSocketLayer(base.Layer):
frame = websockets.Frame.from_file(source_conn.rfile) frame = websockets.Frame.from_file(source_conn.rfile)
if not self._handle_frame(frame, source_conn, other_conn, is_server): cont = self._handle_frame(frame, source_conn, other_conn, is_server)
if not cont:
if close_received:
return return
else:
close_received = True
except (socket.error, exceptions.TcpException, SSL.Error) as e: except (socket.error, exceptions.TcpException, SSL.Error) as e:
self.flow.error = flow.Error("WebSocket connection closed unexpectedly: {}".format(repr(e))) s = 'server' if is_server else 'client'
self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e)))
self.channel.tell("websocket_error", self.flow) self.channel.tell("websocket_error", self.flow)
finally: finally:
self.channel.tell("websocket_end", self.flow) self.channel.tell("websocket_end", self.flow)

View File

@ -276,6 +276,7 @@ class TestClose(_WebSocketTest):
def handle_websockets(cls, rfile, wfile): def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile) frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame)) wfile.write(bytes(frame))
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
wfile.flush() wfile.flush()
with pytest.raises(exceptions.TcpDisconnect): with pytest.raises(exceptions.TcpDisconnect):
@ -287,6 +288,7 @@ class TestClose(_WebSocketTest):
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
client.wfile.flush() client.wfile.flush()
websockets.Frame.from_file(client.rfile)
with pytest.raises(exceptions.TcpDisconnect): with pytest.raises(exceptions.TcpDisconnect):
websockets.Frame.from_file(client.rfile) websockets.Frame.from_file(client.rfile)
@ -296,6 +298,7 @@ class TestClose(_WebSocketTest):
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
client.wfile.flush() client.wfile.flush()
websockets.Frame.from_file(client.rfile)
with pytest.raises(exceptions.TcpDisconnect): with pytest.raises(exceptions.TcpDisconnect):
websockets.Frame.from_file(client.rfile) websockets.Frame.from_file(client.rfile)
@ -305,6 +308,7 @@ class TestClose(_WebSocketTest):
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
client.wfile.flush() client.wfile.flush()
websockets.Frame.from_file(client.rfile)
with pytest.raises(exceptions.TcpDisconnect): with pytest.raises(exceptions.TcpDisconnect):
websockets.Frame.from_file(client.rfile) websockets.Frame.from_file(client.rfile)