mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
prepare WebSocket stack to move to wsproto
This commit is contained in:
parent
8e9194c2b4
commit
130021b76d
@ -3,9 +3,14 @@ import socket
|
|||||||
import struct
|
import struct
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
|
from wsproto import events
|
||||||
|
from wsproto.connection import ConnectionType, WSConnection
|
||||||
|
from wsproto.extensions import PerMessageDeflate
|
||||||
|
|
||||||
from mitmproxy import exceptions
|
from mitmproxy import exceptions
|
||||||
from mitmproxy import flow
|
from mitmproxy import flow
|
||||||
from mitmproxy.proxy.protocol import base
|
from mitmproxy.proxy.protocol import base
|
||||||
|
from mitmproxy.net import http
|
||||||
from mitmproxy.net import tcp
|
from mitmproxy.net import tcp
|
||||||
from mitmproxy.net import websockets
|
from mitmproxy.net import websockets
|
||||||
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
|
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
|
||||||
@ -44,108 +49,139 @@ class WebSocketLayer(base.Layer):
|
|||||||
self.client_frame_buffer = []
|
self.client_frame_buffer = []
|
||||||
self.server_frame_buffer = []
|
self.server_frame_buffer = []
|
||||||
|
|
||||||
def _handle_frame(self, frame, source_conn, other_conn, is_server):
|
self.connections = {} # type: Dict[object, WSConnection]
|
||||||
if frame.header.opcode & 0x8 == 0:
|
|
||||||
return self._handle_data_frame(frame, source_conn, other_conn, is_server)
|
|
||||||
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
|
|
||||||
return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
|
|
||||||
elif frame.header.opcode == websockets.OPCODE.CLOSE:
|
|
||||||
return self._handle_close(frame, source_conn, other_conn, is_server)
|
|
||||||
else:
|
|
||||||
return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
|
|
||||||
|
|
||||||
def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
|
extensions = []
|
||||||
|
if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
|
||||||
|
if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
|
||||||
|
extensions = [PerMessageDeflate.name]
|
||||||
|
|
||||||
fb = self.server_frame_buffer if is_server else self.client_frame_buffer
|
self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
|
||||||
fb.append(frame)
|
extensions=extensions)
|
||||||
|
self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
|
||||||
|
host=handshake_flow.request.host,
|
||||||
|
resource=handshake_flow.request.path,
|
||||||
|
extensions=extensions)
|
||||||
|
|
||||||
if frame.header.fin:
|
data = self.connections[self.server_conn].bytes_to_send()
|
||||||
payload = b''.join(f.payload for f in fb)
|
self.connections[self.client_conn].receive_bytes(data)
|
||||||
original_chunk_sizes = [len(f.payload) for f in fb]
|
|
||||||
message_type = fb[0].header.opcode
|
|
||||||
compressed_message = fb[0].header.rsv1
|
|
||||||
fb.clear()
|
|
||||||
|
|
||||||
websocket_message = WebSocketMessage(message_type, not is_server, payload)
|
event = next(self.connections[self.client_conn].events())
|
||||||
length = len(websocket_message.content)
|
assert isinstance(event, events.ConnectionRequested)
|
||||||
self.flow.messages.append(websocket_message)
|
|
||||||
self.channel.ask("websocket_message", self.flow)
|
|
||||||
|
|
||||||
if not self.flow.stream:
|
self.connections[self.client_conn].accept(event)
|
||||||
def get_chunk(payload):
|
self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send())
|
||||||
if len(payload) == length:
|
assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished)
|
||||||
# message has the same length, we can reuse the same sizes
|
|
||||||
pos = 0
|
|
||||||
for s in original_chunk_sizes:
|
|
||||||
yield payload[pos:pos + s]
|
|
||||||
pos += s
|
|
||||||
else:
|
|
||||||
# just re-chunk everything into 4kB frames
|
|
||||||
# header len = 4 bytes without masking key and 8 bytes with masking key
|
|
||||||
chunk_size = 4092 if is_server else 4088
|
|
||||||
chunks = range(0, len(payload), chunk_size)
|
|
||||||
for i in chunks:
|
|
||||||
yield payload[i:i + chunk_size]
|
|
||||||
|
|
||||||
frms = [
|
def _handle_event(self, event, source_conn, other_conn, is_server):
|
||||||
websockets.Frame(
|
if isinstance(event, events.DataReceived):
|
||||||
payload=chunk,
|
return self._handle_data_received(event, source_conn, other_conn, is_server)
|
||||||
opcode=frame.header.opcode,
|
elif isinstance(event, events.PingReceived):
|
||||||
mask=(False if is_server else 1),
|
return self._handle_ping_received(event, source_conn, other_conn, is_server)
|
||||||
masking_key=(b'' if is_server else os.urandom(4)))
|
elif isinstance(event, events.PongReceived):
|
||||||
for chunk in get_chunk(websocket_message.content)
|
return self._handle_pong_received(event, source_conn, other_conn, is_server)
|
||||||
]
|
elif isinstance(event, events.ConnectionFailed):
|
||||||
|
return self._handle_connection_closed(event, source_conn, other_conn, is_server)
|
||||||
if len(frms) > 0:
|
elif isinstance(event, events.ConnectionFailed):
|
||||||
frms[-1].header.fin = True
|
return self._handle_connection_failed(event)
|
||||||
else:
|
|
||||||
frms.append(websockets.Frame(
|
|
||||||
fin=True,
|
|
||||||
opcode=websockets.OPCODE.CONTINUE,
|
|
||||||
mask=(False if is_server else 1),
|
|
||||||
masking_key=(b'' if is_server else os.urandom(4))))
|
|
||||||
|
|
||||||
frms[0].header.opcode = message_type
|
|
||||||
frms[0].header.rsv1 = compressed_message
|
|
||||||
|
|
||||||
for frm in frms:
|
|
||||||
other_conn.send(bytes(frm))
|
|
||||||
|
|
||||||
else:
|
|
||||||
other_conn.send(bytes(frame))
|
|
||||||
|
|
||||||
elif self.flow.stream:
|
|
||||||
other_conn.send(bytes(frame))
|
|
||||||
|
|
||||||
|
# fail-safe for unhandled events
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
|
def _handle_data_received(self, event, source_conn, other_conn, is_server):
|
||||||
# just forward the ping/pong to the other side
|
|
||||||
other_conn.send(bytes(frame))
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_close(self, frame, source_conn, other_conn, is_server):
|
def _handle_ping_received(self, event, source_conn, other_conn, is_server):
|
||||||
|
# PING is automatically answered with a PONG by wsproto
|
||||||
|
# TODO: log this PING and its payload
|
||||||
|
self.connections[other_conn].ping(event.payload)
|
||||||
|
other_conn.send(self.connections[other_conn].bytes_to_send())
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _handle_pong_received(self, event, source_conn, other_conn, is_server):
|
||||||
|
# TODO: log this PONG and its payload
|
||||||
|
self.connections[other_conn].pong(event.payload)
|
||||||
|
other_conn.send(self.connections[other_conn].bytes_to_send())
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
|
||||||
self.flow.close_sender = "server" if is_server else "client"
|
self.flow.close_sender = "server" if is_server else "client"
|
||||||
if len(frame.payload) >= 2:
|
self.flow.close_code = event.code
|
||||||
code, = struct.unpack('!H', frame.payload[:2])
|
self.flow.close_reason = event.reason
|
||||||
self.flow.close_code = code
|
|
||||||
self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
|
|
||||||
if len(frame.payload) > 2:
|
|
||||||
self.flow.close_reason = frame.payload[2:]
|
|
||||||
|
|
||||||
other_conn.send(bytes(frame))
|
print(self.connections[other_conn])
|
||||||
|
self.connections[other_conn].close(event.code, event.reason)
|
||||||
|
|
||||||
# initiate close handshake
|
# initiate close handshake
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
|
def _handle_connection_failed(self, event):
|
||||||
# unknown frame - just forward it
|
raise exceptions.TcpException(repr(event))
|
||||||
other_conn.send(bytes(frame))
|
|
||||||
|
|
||||||
sender = "server" if is_server else "client"
|
# def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
|
||||||
self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
|
#
|
||||||
|
# fb = self.server_frame_buffer if is_server else self.client_frame_buffer
|
||||||
return True
|
# fb.append(frame)
|
||||||
|
#
|
||||||
|
# if frame.header.fin:
|
||||||
|
# payload = b''.join(f.payload for f in fb)
|
||||||
|
# original_chunk_sizes = [len(f.payload) for f in fb]
|
||||||
|
# message_type = fb[0].header.opcode
|
||||||
|
# compressed_message = fb[0].header.rsv1
|
||||||
|
# fb.clear()
|
||||||
|
#
|
||||||
|
# websocket_message = WebSocketMessage(message_type, not is_server, payload)
|
||||||
|
# length = len(websocket_message.content)
|
||||||
|
# self.flow.messages.append(websocket_message)
|
||||||
|
# self.channel.ask("websocket_message", self.flow)
|
||||||
|
#
|
||||||
|
# if not self.flow.stream:
|
||||||
|
# def get_chunk(payload):
|
||||||
|
# if len(payload) == length:
|
||||||
|
# # message has the same length, we can reuse the same sizes
|
||||||
|
# pos = 0
|
||||||
|
# for s in original_chunk_sizes:
|
||||||
|
# yield payload[pos:pos + s]
|
||||||
|
# pos += s
|
||||||
|
# else:
|
||||||
|
# # just re-chunk everything into 4kB frames
|
||||||
|
# # header len = 4 bytes without masking key and 8 bytes with masking key
|
||||||
|
# chunk_size = 4092 if is_server else 4088
|
||||||
|
# chunks = range(0, len(payload), chunk_size)
|
||||||
|
# for i in chunks:
|
||||||
|
# yield payload[i:i + chunk_size]
|
||||||
|
#
|
||||||
|
# frms = [
|
||||||
|
# websockets.Frame(
|
||||||
|
# payload=chunk,
|
||||||
|
# opcode=frame.header.opcode,
|
||||||
|
# mask=(False if is_server else 1),
|
||||||
|
# masking_key=(b'' if is_server else os.urandom(4)))
|
||||||
|
# for chunk in get_chunk(websocket_message.content)
|
||||||
|
# ]
|
||||||
|
#
|
||||||
|
# if len(frms) > 0:
|
||||||
|
# frms[-1].header.fin = True
|
||||||
|
# else:
|
||||||
|
# frms.append(websockets.Frame(
|
||||||
|
# fin=True,
|
||||||
|
# opcode=websockets.OPCODE.CONTINUE,
|
||||||
|
# mask=(False if is_server else 1),
|
||||||
|
# masking_key=(b'' if is_server else os.urandom(4))))
|
||||||
|
#
|
||||||
|
# frms[0].header.opcode = message_type
|
||||||
|
# frms[0].header.rsv1 = compressed_message
|
||||||
|
#
|
||||||
|
# for frm in frms:
|
||||||
|
# other_conn.send(bytes(frm))
|
||||||
|
#
|
||||||
|
# else:
|
||||||
|
# other_conn.send(bytes(frame))
|
||||||
|
#
|
||||||
|
# elif self.flow.stream:
|
||||||
|
# other_conn.send(bytes(frame))
|
||||||
|
#
|
||||||
|
# return True
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
|
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
|
||||||
@ -153,25 +189,26 @@ class WebSocketLayer(base.Layer):
|
|||||||
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
|
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
|
||||||
self.channel.ask("websocket_start", self.flow)
|
self.channel.ask("websocket_start", self.flow)
|
||||||
|
|
||||||
client = self.client_conn.connection
|
conns = [c.connection for c in self.connections.keys()]
|
||||||
server = self.server_conn.connection
|
|
||||||
conns = [client, server]
|
|
||||||
close_received = False
|
close_received = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not self.channel.should_exit.is_set():
|
while not self.channel.should_exit.is_set():
|
||||||
r = tcp.ssl_read_select(conns, 0.1)
|
r = tcp.ssl_read_select(conns, 0.1)
|
||||||
for conn in r:
|
for conn in r:
|
||||||
source_conn = self.client_conn if conn == client else self.server_conn
|
source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
|
||||||
other_conn = self.server_conn if conn == client else self.client_conn
|
other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
|
||||||
is_server = (conn == self.server_conn.connection)
|
is_server = (source_conn == self.server_conn)
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(source_conn.rfile)
|
frame = websockets.Frame.from_file(source_conn.rfile)
|
||||||
|
self.connections[source_conn].receive_bytes(bytes(frame))
|
||||||
|
source_conn.send(self.connections[source_conn].bytes_to_send())
|
||||||
|
|
||||||
cont = self._handle_frame(frame, source_conn, other_conn, is_server)
|
for event in self.connections[source_conn].events():
|
||||||
if not cont:
|
print('is_server:', is_server, 'event:', event)
|
||||||
|
if not self._handle_event(event, source_conn, other_conn, is_server):
|
||||||
if close_received:
|
if close_received:
|
||||||
return
|
break
|
||||||
else:
|
else:
|
||||||
close_received = True
|
close_received = True
|
||||||
except (socket.error, exceptions.TcpException, SSL.Error) as e:
|
except (socket.error, exceptions.TcpException, SSL.Error) as e:
|
||||||
|
@ -164,19 +164,19 @@ class TestSimple(_WebSocketTest):
|
|||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'server-foobar'
|
assert frame.payload == b'server-foobar'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'self.client-foobar'
|
assert frame.payload == b'self.client-foobar'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'\xde\xad\xbe\xef'
|
assert frame.payload == b'\xde\xad\xbe\xef'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
assert len(self.master.state.flows) == 2
|
assert len(self.master.state.flows) == 2
|
||||||
@ -213,13 +213,13 @@ class TestSimpleTLS(_WebSocketTest):
|
|||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'server-foobar'
|
assert frame.payload == b'server-foobar'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'self.client-foobar'
|
assert frame.payload == b'self.client-foobar'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
|
||||||
@ -234,7 +234,7 @@ class TestPing(_WebSocketTest):
|
|||||||
assert frame.header.opcode == websockets.OPCODE.PONG
|
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||||
assert frame.payload == b'foobar'
|
assert frame.payload == b'foobar'
|
||||||
|
|
||||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done')))
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
def test_ping(self):
|
def test_ping(self):
|
||||||
@ -244,12 +244,12 @@ class TestPing(_WebSocketTest):
|
|||||||
assert frame.header.opcode == websockets.OPCODE.PING
|
assert frame.header.opcode == websockets.OPCODE.PING
|
||||||
assert frame.payload == b'foobar'
|
assert frame.payload == b'foobar'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.header.opcode == websockets.OPCODE.TEXT
|
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||||
assert frame.payload == b'pong-received'
|
assert frame.payload == b'done'
|
||||||
|
|
||||||
|
|
||||||
class TestPong(_WebSocketTest):
|
class TestPong(_WebSocketTest):
|
||||||
@ -266,7 +266,7 @@ class TestPong(_WebSocketTest):
|
|||||||
def test_pong(self):
|
def test_pong(self):
|
||||||
self.setup_connection()
|
self.setup_connection()
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
@ -289,7 +289,7 @@ class TestClose(_WebSocketTest):
|
|||||||
def test_close(self):
|
def test_close(self):
|
||||||
self.setup_connection()
|
self.setup_connection()
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
websockets.Frame.from_file(self.client.rfile)
|
websockets.Frame.from_file(self.client.rfile)
|
||||||
@ -299,7 +299,7 @@ class TestClose(_WebSocketTest):
|
|||||||
def test_close_payload_1(self):
|
def test_close_payload_1(self):
|
||||||
self.setup_connection()
|
self.setup_connection()
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
websockets.Frame.from_file(self.client.rfile)
|
websockets.Frame.from_file(self.client.rfile)
|
||||||
@ -309,7 +309,7 @@ class TestClose(_WebSocketTest):
|
|||||||
def test_close_payload_2(self):
|
def test_close_payload_2(self):
|
||||||
self.setup_connection()
|
self.setup_connection()
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
websockets.Frame.from_file(self.client.rfile)
|
websockets.Frame.from_file(self.client.rfile)
|
||||||
|
Loading…
Reference in New Issue
Block a user