mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
prepare WebSocket stack to move to wsproto
This commit is contained in:
parent
8e9194c2b4
commit
130021b76d
@ -3,9 +3,14 @@ import socket
|
||||
import struct
|
||||
from OpenSSL import SSL
|
||||
|
||||
from wsproto import events
|
||||
from wsproto.connection import ConnectionType, WSConnection
|
||||
from wsproto.extensions import PerMessageDeflate
|
||||
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import flow
|
||||
from mitmproxy.proxy.protocol import base
|
||||
from mitmproxy.net import http
|
||||
from mitmproxy.net import tcp
|
||||
from mitmproxy.net import websockets
|
||||
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
|
||||
@ -44,108 +49,139 @@ class WebSocketLayer(base.Layer):
|
||||
self.client_frame_buffer = []
|
||||
self.server_frame_buffer = []
|
||||
|
||||
def _handle_frame(self, frame, source_conn, other_conn, is_server):
|
||||
if frame.header.opcode & 0x8 == 0:
|
||||
return self._handle_data_frame(frame, source_conn, other_conn, is_server)
|
||||
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
|
||||
return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
|
||||
elif frame.header.opcode == websockets.OPCODE.CLOSE:
|
||||
return self._handle_close(frame, source_conn, other_conn, is_server)
|
||||
else:
|
||||
return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
|
||||
self.connections = {} # type: Dict[object, WSConnection]
|
||||
|
||||
def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
|
||||
extensions = []
|
||||
if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
|
||||
if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
|
||||
extensions = [PerMessageDeflate.name]
|
||||
|
||||
fb = self.server_frame_buffer if is_server else self.client_frame_buffer
|
||||
fb.append(frame)
|
||||
self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
|
||||
extensions=extensions)
|
||||
self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
|
||||
host=handshake_flow.request.host,
|
||||
resource=handshake_flow.request.path,
|
||||
extensions=extensions)
|
||||
|
||||
if frame.header.fin:
|
||||
payload = b''.join(f.payload for f in fb)
|
||||
original_chunk_sizes = [len(f.payload) for f in fb]
|
||||
message_type = fb[0].header.opcode
|
||||
compressed_message = fb[0].header.rsv1
|
||||
fb.clear()
|
||||
data = self.connections[self.server_conn].bytes_to_send()
|
||||
self.connections[self.client_conn].receive_bytes(data)
|
||||
|
||||
websocket_message = WebSocketMessage(message_type, not is_server, payload)
|
||||
length = len(websocket_message.content)
|
||||
self.flow.messages.append(websocket_message)
|
||||
self.channel.ask("websocket_message", self.flow)
|
||||
event = next(self.connections[self.client_conn].events())
|
||||
assert isinstance(event, events.ConnectionRequested)
|
||||
|
||||
if not self.flow.stream:
|
||||
def get_chunk(payload):
|
||||
if len(payload) == length:
|
||||
# message has the same length, we can reuse the same sizes
|
||||
pos = 0
|
||||
for s in original_chunk_sizes:
|
||||
yield payload[pos:pos + s]
|
||||
pos += s
|
||||
else:
|
||||
# just re-chunk everything into 4kB frames
|
||||
# header len = 4 bytes without masking key and 8 bytes with masking key
|
||||
chunk_size = 4092 if is_server else 4088
|
||||
chunks = range(0, len(payload), chunk_size)
|
||||
for i in chunks:
|
||||
yield payload[i:i + chunk_size]
|
||||
self.connections[self.client_conn].accept(event)
|
||||
self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send())
|
||||
assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished)
|
||||
|
||||
frms = [
|
||||
websockets.Frame(
|
||||
payload=chunk,
|
||||
opcode=frame.header.opcode,
|
||||
mask=(False if is_server else 1),
|
||||
masking_key=(b'' if is_server else os.urandom(4)))
|
||||
for chunk in get_chunk(websocket_message.content)
|
||||
]
|
||||
|
||||
if len(frms) > 0:
|
||||
frms[-1].header.fin = True
|
||||
else:
|
||||
frms.append(websockets.Frame(
|
||||
fin=True,
|
||||
opcode=websockets.OPCODE.CONTINUE,
|
||||
mask=(False if is_server else 1),
|
||||
masking_key=(b'' if is_server else os.urandom(4))))
|
||||
|
||||
frms[0].header.opcode = message_type
|
||||
frms[0].header.rsv1 = compressed_message
|
||||
|
||||
for frm in frms:
|
||||
other_conn.send(bytes(frm))
|
||||
|
||||
else:
|
||||
other_conn.send(bytes(frame))
|
||||
|
||||
elif self.flow.stream:
|
||||
other_conn.send(bytes(frame))
|
||||
def _handle_event(self, event, source_conn, other_conn, is_server):
|
||||
if isinstance(event, events.DataReceived):
|
||||
return self._handle_data_received(event, source_conn, other_conn, is_server)
|
||||
elif isinstance(event, events.PingReceived):
|
||||
return self._handle_ping_received(event, source_conn, other_conn, is_server)
|
||||
elif isinstance(event, events.PongReceived):
|
||||
return self._handle_pong_received(event, source_conn, other_conn, is_server)
|
||||
elif isinstance(event, events.ConnectionFailed):
|
||||
return self._handle_connection_closed(event, source_conn, other_conn, is_server)
|
||||
elif isinstance(event, events.ConnectionFailed):
|
||||
return self._handle_connection_failed(event)
|
||||
|
||||
# fail-safe for unhandled events
|
||||
return True
|
||||
|
||||
def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
|
||||
# just forward the ping/pong to the other side
|
||||
other_conn.send(bytes(frame))
|
||||
def _handle_data_received(self, event, source_conn, other_conn, is_server):
|
||||
return True
|
||||
|
||||
def _handle_close(self, frame, source_conn, other_conn, is_server):
|
||||
def _handle_ping_received(self, event, source_conn, other_conn, is_server):
|
||||
# PING is automatically answered with a PONG by wsproto
|
||||
# TODO: log this PING and its payload
|
||||
self.connections[other_conn].ping(event.payload)
|
||||
other_conn.send(self.connections[other_conn].bytes_to_send())
|
||||
return True
|
||||
|
||||
def _handle_pong_received(self, event, source_conn, other_conn, is_server):
|
||||
# TODO: log this PONG and its payload
|
||||
self.connections[other_conn].pong(event.payload)
|
||||
other_conn.send(self.connections[other_conn].bytes_to_send())
|
||||
return True
|
||||
|
||||
def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
|
||||
self.flow.close_sender = "server" if is_server else "client"
|
||||
if len(frame.payload) >= 2:
|
||||
code, = struct.unpack('!H', frame.payload[:2])
|
||||
self.flow.close_code = code
|
||||
self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
|
||||
if len(frame.payload) > 2:
|
||||
self.flow.close_reason = frame.payload[2:]
|
||||
self.flow.close_code = event.code
|
||||
self.flow.close_reason = event.reason
|
||||
|
||||
other_conn.send(bytes(frame))
|
||||
print(self.connections[other_conn])
|
||||
self.connections[other_conn].close(event.code, event.reason)
|
||||
|
||||
# initiate close handshake
|
||||
return False
|
||||
|
||||
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
|
||||
# unknown frame - just forward it
|
||||
other_conn.send(bytes(frame))
|
||||
def _handle_connection_failed(self, event):
|
||||
raise exceptions.TcpException(repr(event))
|
||||
|
||||
sender = "server" if is_server else "client"
|
||||
self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
|
||||
|
||||
return True
|
||||
# def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
|
||||
#
|
||||
# fb = self.server_frame_buffer if is_server else self.client_frame_buffer
|
||||
# fb.append(frame)
|
||||
#
|
||||
# if frame.header.fin:
|
||||
# payload = b''.join(f.payload for f in fb)
|
||||
# original_chunk_sizes = [len(f.payload) for f in fb]
|
||||
# message_type = fb[0].header.opcode
|
||||
# compressed_message = fb[0].header.rsv1
|
||||
# fb.clear()
|
||||
#
|
||||
# websocket_message = WebSocketMessage(message_type, not is_server, payload)
|
||||
# length = len(websocket_message.content)
|
||||
# self.flow.messages.append(websocket_message)
|
||||
# self.channel.ask("websocket_message", self.flow)
|
||||
#
|
||||
# if not self.flow.stream:
|
||||
# def get_chunk(payload):
|
||||
# if len(payload) == length:
|
||||
# # message has the same length, we can reuse the same sizes
|
||||
# pos = 0
|
||||
# for s in original_chunk_sizes:
|
||||
# yield payload[pos:pos + s]
|
||||
# pos += s
|
||||
# else:
|
||||
# # just re-chunk everything into 4kB frames
|
||||
# # header len = 4 bytes without masking key and 8 bytes with masking key
|
||||
# chunk_size = 4092 if is_server else 4088
|
||||
# chunks = range(0, len(payload), chunk_size)
|
||||
# for i in chunks:
|
||||
# yield payload[i:i + chunk_size]
|
||||
#
|
||||
# frms = [
|
||||
# websockets.Frame(
|
||||
# payload=chunk,
|
||||
# opcode=frame.header.opcode,
|
||||
# mask=(False if is_server else 1),
|
||||
# masking_key=(b'' if is_server else os.urandom(4)))
|
||||
# for chunk in get_chunk(websocket_message.content)
|
||||
# ]
|
||||
#
|
||||
# if len(frms) > 0:
|
||||
# frms[-1].header.fin = True
|
||||
# else:
|
||||
# frms.append(websockets.Frame(
|
||||
# fin=True,
|
||||
# opcode=websockets.OPCODE.CONTINUE,
|
||||
# mask=(False if is_server else 1),
|
||||
# masking_key=(b'' if is_server else os.urandom(4))))
|
||||
#
|
||||
# frms[0].header.opcode = message_type
|
||||
# frms[0].header.rsv1 = compressed_message
|
||||
#
|
||||
# for frm in frms:
|
||||
# other_conn.send(bytes(frm))
|
||||
#
|
||||
# else:
|
||||
# other_conn.send(bytes(frame))
|
||||
#
|
||||
# elif self.flow.stream:
|
||||
# other_conn.send(bytes(frame))
|
||||
#
|
||||
# return True
|
||||
|
||||
def __call__(self):
|
||||
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
|
||||
@ -153,27 +189,28 @@ class WebSocketLayer(base.Layer):
|
||||
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
|
||||
self.channel.ask("websocket_start", self.flow)
|
||||
|
||||
client = self.client_conn.connection
|
||||
server = self.server_conn.connection
|
||||
conns = [client, server]
|
||||
conns = [c.connection for c in self.connections.keys()]
|
||||
close_received = False
|
||||
|
||||
try:
|
||||
while not self.channel.should_exit.is_set():
|
||||
r = tcp.ssl_read_select(conns, 0.1)
|
||||
for conn in r:
|
||||
source_conn = self.client_conn if conn == client else self.server_conn
|
||||
other_conn = self.server_conn if conn == client else self.client_conn
|
||||
is_server = (conn == self.server_conn.connection)
|
||||
source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
|
||||
other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
|
||||
is_server = (source_conn == self.server_conn)
|
||||
|
||||
frame = websockets.Frame.from_file(source_conn.rfile)
|
||||
self.connections[source_conn].receive_bytes(bytes(frame))
|
||||
source_conn.send(self.connections[source_conn].bytes_to_send())
|
||||
|
||||
cont = self._handle_frame(frame, source_conn, other_conn, is_server)
|
||||
if not cont:
|
||||
if close_received:
|
||||
return
|
||||
else:
|
||||
close_received = True
|
||||
for event in self.connections[source_conn].events():
|
||||
print('is_server:', is_server, 'event:', event)
|
||||
if not self._handle_event(event, source_conn, other_conn, is_server):
|
||||
if close_received:
|
||||
break
|
||||
else:
|
||||
close_received = True
|
||||
except (socket.error, exceptions.TcpException, SSL.Error) as e:
|
||||
s = 'server' if is_server else 'client'
|
||||
self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e)))
|
||||
|
@ -164,19 +164,19 @@ class TestSimple(_WebSocketTest):
|
||||
frame = websockets.Frame.from_file(self.client.rfile)
|
||||
assert frame.payload == b'server-foobar'
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
||||
self.client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(self.client.rfile)
|
||||
assert frame.payload == b'self.client-foobar'
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
|
||||
self.client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(self.client.rfile)
|
||||
assert frame.payload == b'\xde\xad\xbe\xef'
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
self.client.wfile.flush()
|
||||
|
||||
assert len(self.master.state.flows) == 2
|
||||
@ -213,13 +213,13 @@ class TestSimpleTLS(_WebSocketTest):
|
||||
frame = websockets.Frame.from_file(self.client.rfile)
|
||||
assert frame.payload == b'server-foobar'
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
||||
self.client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(self.client.rfile)
|
||||
assert frame.payload == b'self.client-foobar'
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
self.client.wfile.flush()
|
||||
|
||||
|
||||
@ -234,7 +234,7 @@ class TestPing(_WebSocketTest):
|
||||
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||
assert frame.payload == b'foobar'
|
||||
|
||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
|
||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done')))
|
||||
wfile.flush()
|
||||
|
||||
def test_ping(self):
|
||||
@ -244,12 +244,12 @@ class TestPing(_WebSocketTest):
|
||||
assert frame.header.opcode == websockets.OPCODE.PING
|
||||
assert frame.payload == b'foobar'
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
||||
self.client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(self.client.rfile)
|
||||
assert frame.header.opcode == websockets.OPCODE.TEXT
|
||||
assert frame.payload == b'pong-received'
|
||||
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||
assert frame.payload == b'done'
|
||||
|
||||
|
||||
class TestPong(_WebSocketTest):
|
||||
@ -266,7 +266,7 @@ class TestPong(_WebSocketTest):
|
||||
def test_pong(self):
|
||||
self.setup_connection()
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
||||
self.client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(self.client.rfile)
|
||||
@ -289,7 +289,7 @@ class TestClose(_WebSocketTest):
|
||||
def test_close(self):
|
||||
self.setup_connection()
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
self.client.wfile.flush()
|
||||
|
||||
websockets.Frame.from_file(self.client.rfile)
|
||||
@ -299,7 +299,7 @@ class TestClose(_WebSocketTest):
|
||||
def test_close_payload_1(self):
|
||||
self.setup_connection()
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
||||
self.client.wfile.flush()
|
||||
|
||||
websockets.Frame.from_file(self.client.rfile)
|
||||
@ -309,7 +309,7 @@ class TestClose(_WebSocketTest):
|
||||
def test_close_payload_2(self):
|
||||
self.setup_connection()
|
||||
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
||||
self.client.wfile.flush()
|
||||
|
||||
websockets.Frame.from_file(self.client.rfile)
|
||||
|
Loading…
Reference in New Issue
Block a user