mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
Merge pull request #1797 from Kriechi/websocket++
Pass Autobahn WebSocket Test Suite
This commit is contained in:
commit
d658783dec
@ -37,6 +37,7 @@ import sys
|
|||||||
import functools
|
import functools
|
||||||
|
|
||||||
from mitmproxy import http
|
from mitmproxy import http
|
||||||
|
from mitmproxy import websocket
|
||||||
from mitmproxy import tcp
|
from mitmproxy import tcp
|
||||||
from mitmproxy import flow
|
from mitmproxy import flow
|
||||||
|
|
||||||
@ -99,6 +100,14 @@ class FHTTP(_Action):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class FWebSocket(_Action):
|
||||||
|
code = "websocket"
|
||||||
|
help = "Match WebSocket flows"
|
||||||
|
|
||||||
|
@only(websocket.WebSocketFlow)
|
||||||
|
def __call__(self, f):
|
||||||
|
return True
|
||||||
|
|
||||||
class FTCP(_Action):
|
class FTCP(_Action):
|
||||||
code = "tcp"
|
code = "tcp"
|
||||||
help = "Match TCP flows"
|
help = "Match TCP flows"
|
||||||
@ -245,7 +254,7 @@ class FBod(_Rex):
|
|||||||
help = "Body"
|
help = "Body"
|
||||||
flags = re.DOTALL
|
flags = re.DOTALL
|
||||||
|
|
||||||
@only(http.HTTPFlow, tcp.TCPFlow)
|
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow)
|
||||||
def __call__(self, f):
|
def __call__(self, f):
|
||||||
if isinstance(f, http.HTTPFlow):
|
if isinstance(f, http.HTTPFlow):
|
||||||
if f.request and f.request.raw_content:
|
if f.request and f.request.raw_content:
|
||||||
@ -254,7 +263,7 @@ class FBod(_Rex):
|
|||||||
if f.response and f.response.raw_content:
|
if f.response and f.response.raw_content:
|
||||||
if self.re.search(f.response.get_content(strict=False)):
|
if self.re.search(f.response.get_content(strict=False)):
|
||||||
return True
|
return True
|
||||||
elif isinstance(f, tcp.TCPFlow):
|
elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow):
|
||||||
for msg in f.messages:
|
for msg in f.messages:
|
||||||
if self.re.search(msg.content):
|
if self.re.search(msg.content):
|
||||||
return True
|
return True
|
||||||
@ -266,13 +275,13 @@ class FBodRequest(_Rex):
|
|||||||
help = "Request body"
|
help = "Request body"
|
||||||
flags = re.DOTALL
|
flags = re.DOTALL
|
||||||
|
|
||||||
@only(http.HTTPFlow, tcp.TCPFlow)
|
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow)
|
||||||
def __call__(self, f):
|
def __call__(self, f):
|
||||||
if isinstance(f, http.HTTPFlow):
|
if isinstance(f, http.HTTPFlow):
|
||||||
if f.request and f.request.raw_content:
|
if f.request and f.request.raw_content:
|
||||||
if self.re.search(f.request.get_content(strict=False)):
|
if self.re.search(f.request.get_content(strict=False)):
|
||||||
return True
|
return True
|
||||||
elif isinstance(f, tcp.TCPFlow):
|
elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow):
|
||||||
for msg in f.messages:
|
for msg in f.messages:
|
||||||
if msg.from_client and self.re.search(msg.content):
|
if msg.from_client and self.re.search(msg.content):
|
||||||
return True
|
return True
|
||||||
@ -283,13 +292,13 @@ class FBodResponse(_Rex):
|
|||||||
help = "Response body"
|
help = "Response body"
|
||||||
flags = re.DOTALL
|
flags = re.DOTALL
|
||||||
|
|
||||||
@only(http.HTTPFlow, tcp.TCPFlow)
|
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow)
|
||||||
def __call__(self, f):
|
def __call__(self, f):
|
||||||
if isinstance(f, http.HTTPFlow):
|
if isinstance(f, http.HTTPFlow):
|
||||||
if f.response and f.response.raw_content:
|
if f.response and f.response.raw_content:
|
||||||
if self.re.search(f.response.get_content(strict=False)):
|
if self.re.search(f.response.get_content(strict=False)):
|
||||||
return True
|
return True
|
||||||
elif isinstance(f, tcp.TCPFlow):
|
elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow):
|
||||||
for msg in f.messages:
|
for msg in f.messages:
|
||||||
if not msg.from_client and self.re.search(msg.content):
|
if not msg.from_client and self.re.search(msg.content):
|
||||||
return True
|
return True
|
||||||
|
@ -59,30 +59,56 @@ class WebSocketLayer(base.Layer):
|
|||||||
fb.append(frame)
|
fb.append(frame)
|
||||||
|
|
||||||
if frame.header.fin:
|
if frame.header.fin:
|
||||||
if frame.header.opcode == websockets.OPCODE.TEXT:
|
payload = b''.join(f.payload for f in fb)
|
||||||
|
original_chunk_sizes = [len(f.payload) for f in fb]
|
||||||
|
message_type = fb[0].header.opcode
|
||||||
|
compressed_message = fb[0].header.rsv1
|
||||||
|
fb.clear()
|
||||||
|
|
||||||
|
if message_type == websockets.OPCODE.TEXT:
|
||||||
t = WebSocketTextMessage
|
t = WebSocketTextMessage
|
||||||
else:
|
else:
|
||||||
t = WebSocketBinaryMessage
|
t = WebSocketBinaryMessage
|
||||||
|
|
||||||
payload = b''.join(f.payload for f in fb)
|
|
||||||
fb.clear()
|
|
||||||
|
|
||||||
websocket_message = t(self.flow, not is_server, payload)
|
websocket_message = t(self.flow, not is_server, payload)
|
||||||
|
length = len(websocket_message.content)
|
||||||
self.flow.messages.append(websocket_message)
|
self.flow.messages.append(websocket_message)
|
||||||
self.channel.ask("websocket_message", self.flow)
|
self.channel.ask("websocket_message", self.flow)
|
||||||
|
|
||||||
# chunk payload into multiple 10kB frames, and send them
|
def get_chunk(payload):
|
||||||
payload = websocket_message.content
|
if len(payload) == length:
|
||||||
chunk_size = 10240 # 10kB
|
# message has the same length, we can reuse the same sizes
|
||||||
|
pos = 0
|
||||||
|
for s in original_chunk_sizes:
|
||||||
|
yield payload[pos:pos + s]
|
||||||
|
pos += s
|
||||||
|
else:
|
||||||
|
# just re-chunk everything into 10kB frames
|
||||||
|
chunk_size = 10240
|
||||||
chunks = range(0, len(payload), chunk_size)
|
chunks = range(0, len(payload), chunk_size)
|
||||||
|
for i in chunks:
|
||||||
|
yield payload[i:i + chunk_size]
|
||||||
|
|
||||||
frms = [
|
frms = [
|
||||||
websockets.Frame(
|
websockets.Frame(
|
||||||
payload=payload[i:i + chunk_size],
|
payload=chunk,
|
||||||
opcode=frame.header.opcode,
|
opcode=frame.header.opcode,
|
||||||
mask=(False if is_server else 1),
|
mask=(False if is_server else 1),
|
||||||
masking_key=(b'' if is_server else os.urandom(4))) for i in chunks
|
masking_key=(b'' if is_server else os.urandom(4)))
|
||||||
|
for chunk in get_chunk(websocket_message.content)
|
||||||
]
|
]
|
||||||
frms[-1].header.fin = 1
|
|
||||||
|
if len(frms) > 0:
|
||||||
|
frms[-1].header.fin = True
|
||||||
|
else:
|
||||||
|
frms.append(websockets.Frame(
|
||||||
|
fin=True,
|
||||||
|
opcode=websockets.OPCODE.CONTINUE,
|
||||||
|
mask=(False if is_server else 1),
|
||||||
|
masking_key=(b'' if is_server else os.urandom(4))))
|
||||||
|
|
||||||
|
frms[0].header.opcode = message_type
|
||||||
|
frms[0].header.rsv1 = compressed_message
|
||||||
|
|
||||||
for frm in frms:
|
for frm in frms:
|
||||||
other_conn.send(bytes(frm))
|
other_conn.send(bytes(frm))
|
||||||
@ -105,7 +131,7 @@ class WebSocketLayer(base.Layer):
|
|||||||
|
|
||||||
other_conn.send(bytes(frame))
|
other_conn.send(bytes(frame))
|
||||||
|
|
||||||
# close the connection
|
# initiate close handshake
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
|
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
|
||||||
@ -126,10 +152,11 @@ class WebSocketLayer(base.Layer):
|
|||||||
client = self.client_conn.connection
|
client = self.client_conn.connection
|
||||||
server = self.server_conn.connection
|
server = self.server_conn.connection
|
||||||
conns = [client, server]
|
conns = [client, server]
|
||||||
|
close_received = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not self.channel.should_exit.is_set():
|
while not self.channel.should_exit.is_set():
|
||||||
r = tcp.ssl_read_select(conns, 1)
|
r = tcp.ssl_read_select(conns, 0.1)
|
||||||
for conn in r:
|
for conn in r:
|
||||||
source_conn = self.client_conn if conn == client else self.server_conn
|
source_conn = self.client_conn if conn == client else self.server_conn
|
||||||
other_conn = self.server_conn if conn == client else self.client_conn
|
other_conn = self.server_conn if conn == client else self.client_conn
|
||||||
@ -137,10 +164,15 @@ class WebSocketLayer(base.Layer):
|
|||||||
|
|
||||||
frame = websockets.Frame.from_file(source_conn.rfile)
|
frame = websockets.Frame.from_file(source_conn.rfile)
|
||||||
|
|
||||||
if not self._handle_frame(frame, source_conn, other_conn, is_server):
|
cont = self._handle_frame(frame, source_conn, other_conn, is_server)
|
||||||
|
if not cont:
|
||||||
|
if close_received:
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
|
close_received = True
|
||||||
except (socket.error, exceptions.TcpException, SSL.Error) as e:
|
except (socket.error, exceptions.TcpException, SSL.Error) as e:
|
||||||
self.flow.error = flow.Error("WebSocket connection closed unexpectedly: {}".format(repr(e)))
|
s = 'server' if is_server else 'client'
|
||||||
|
self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e)))
|
||||||
self.channel.tell("websocket_error", self.flow)
|
self.channel.tell("websocket_error", self.flow)
|
||||||
finally:
|
finally:
|
||||||
self.channel.tell("websocket_end", self.flow)
|
self.channel.tell("websocket_end", self.flow)
|
||||||
|
@ -276,6 +276,7 @@ class TestClose(_WebSocketTest):
|
|||||||
def handle_websockets(cls, rfile, wfile):
|
def handle_websockets(cls, rfile, wfile):
|
||||||
frame = websockets.Frame.from_file(rfile)
|
frame = websockets.Frame.from_file(rfile)
|
||||||
wfile.write(bytes(frame))
|
wfile.write(bytes(frame))
|
||||||
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
with pytest.raises(exceptions.TcpDisconnect):
|
with pytest.raises(exceptions.TcpDisconnect):
|
||||||
@ -287,6 +288,7 @@ class TestClose(_WebSocketTest):
|
|||||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
client.wfile.flush()
|
client.wfile.flush()
|
||||||
|
|
||||||
|
websockets.Frame.from_file(client.rfile)
|
||||||
with pytest.raises(exceptions.TcpDisconnect):
|
with pytest.raises(exceptions.TcpDisconnect):
|
||||||
websockets.Frame.from_file(client.rfile)
|
websockets.Frame.from_file(client.rfile)
|
||||||
|
|
||||||
@ -296,6 +298,7 @@ class TestClose(_WebSocketTest):
|
|||||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
||||||
client.wfile.flush()
|
client.wfile.flush()
|
||||||
|
|
||||||
|
websockets.Frame.from_file(client.rfile)
|
||||||
with pytest.raises(exceptions.TcpDisconnect):
|
with pytest.raises(exceptions.TcpDisconnect):
|
||||||
websockets.Frame.from_file(client.rfile)
|
websockets.Frame.from_file(client.rfile)
|
||||||
|
|
||||||
@ -305,6 +308,7 @@ class TestClose(_WebSocketTest):
|
|||||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
||||||
client.wfile.flush()
|
client.wfile.flush()
|
||||||
|
|
||||||
|
websockets.Frame.from_file(client.rfile)
|
||||||
with pytest.raises(exceptions.TcpDisconnect):
|
with pytest.raises(exceptions.TcpDisconnect):
|
||||||
websockets.Frame.from_file(client.rfile)
|
websockets.Frame.from_file(client.rfile)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user