Merge pull request #3459 from RPigott/wsproto-update

Fix: update to wsproto 0.13 for real this time
This commit is contained in:
Thomas Kriechbaumer 2019-01-31 22:08:02 +01:00 committed by GitHub
commit 70777a1b6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 48 deletions

View File

@ -4,8 +4,9 @@ from OpenSSL import SSL
import wsproto import wsproto
from wsproto import events from wsproto import events, WSConnection
from wsproto.connection import ConnectionType, WSConnection from wsproto.connection import ConnectionType
from wsproto.events import AcceptConnection, CloseConnection, Message, Ping, Request
from wsproto.extensions import PerMessageDeflate from wsproto.extensions import PerMessageDeflate
from mitmproxy import exceptions from mitmproxy import exceptions
@ -52,51 +53,52 @@ class WebSocketLayer(base.Layer):
self.connections: dict[object, WSConnection] = {} self.connections: dict[object, WSConnection] = {}
extensions = [] client_extensions = []
server_extensions = []
if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers: if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']: if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
extensions = [PerMessageDeflate()] client_extensions = [PerMessageDeflate()]
self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER, server_extensions = [PerMessageDeflate()]
extensions=extensions) self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER)
self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT, self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT)
host=handshake_flow.request.host,
resource=handshake_flow.request.path,
extensions=extensions)
if extensions:
for conn in self.connections.values():
conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions'])
data = self.connections[self.server_conn].bytes_to_send() if client_extensions:
self.connections[self.client_conn].receive_bytes(data) client_extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
if server_extensions:
server_extensions[0].finalize(handshake_flow.response.headers['Sec-WebSocket-Extensions'])
request = Request(extensions=client_extensions, host=handshake_flow.request.host, target=handshake_flow.request.path)
data = self.connections[self.server_conn].send(request)
self.connections[self.client_conn].receive_data(data)
event = next(self.connections[self.client_conn].events()) event = next(self.connections[self.client_conn].events())
assert isinstance(event, events.ConnectionRequested) assert isinstance(event, events.Request)
self.connections[self.client_conn].accept(event) data = self.connections[self.client_conn].send(AcceptConnection(extensions=server_extensions))
self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send()) self.connections[self.server_conn].receive_data(data)
assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished) assert isinstance(next(self.connections[self.server_conn].events()), events.AcceptConnection)
def _handle_event(self, event, source_conn, other_conn, is_server): def _handle_event(self, event, source_conn, other_conn, is_server):
if isinstance(event, events.DataReceived): if isinstance(event, events.Message):
return self._handle_data_received(event, source_conn, other_conn, is_server) return self._handle_message(event, source_conn, other_conn, is_server)
elif isinstance(event, events.PingReceived): elif isinstance(event, events.Ping):
return self._handle_ping_received(event, source_conn, other_conn, is_server) return self._handle_ping(event, source_conn, other_conn, is_server)
elif isinstance(event, events.PongReceived): elif isinstance(event, events.Pong):
return self._handle_pong_received(event, source_conn, other_conn, is_server) return self._handle_pong(event, source_conn, other_conn, is_server)
elif isinstance(event, events.ConnectionClosed): elif isinstance(event, events.CloseConnection):
return self._handle_connection_closed(event, source_conn, other_conn, is_server) return self._handle_close_connection(event, source_conn, other_conn, is_server)
# fail-safe for unhandled events # fail-safe for unhandled events
return True # pragma: no cover return True # pragma: no cover
def _handle_data_received(self, event, source_conn, other_conn, is_server): def _handle_message(self, event, source_conn, other_conn, is_server):
fb = self.server_frame_buffer if is_server else self.client_frame_buffer fb = self.server_frame_buffer if is_server else self.client_frame_buffer
fb.append(event.data) fb.append(event.data)
if event.message_finished: if event.message_finished:
original_chunk_sizes = [len(f) for f in fb] original_chunk_sizes = [len(f) for f in fb]
if isinstance(event, events.TextReceived): if isinstance(event, events.TextMessage):
message_type = wsproto.frame_protocol.Opcode.TEXT message_type = wsproto.frame_protocol.Opcode.TEXT
payload = ''.join(fb) payload = ''.join(fb)
else: else:
@ -127,19 +129,20 @@ class WebSocketLayer(base.Layer):
yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False) yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
for chunk, final in get_chunk(websocket_message.content): for chunk, final in get_chunk(websocket_message.content):
self.connections[other_conn].send_data(chunk, final) data = self.connections[other_conn].send(Message(data=chunk, message_finished=final))
other_conn.send(self.connections[other_conn].bytes_to_send()) other_conn.send(data)
if self.flow.stream: if self.flow.stream:
self.connections[other_conn].send_data(event.data, event.message_finished) data = self.connections[other_conn].send(Message(data=event.data, message_finished=event.message_finished))
other_conn.send(self.connections[other_conn].bytes_to_send()) other_conn.send(data)
return True return True
def _handle_ping_received(self, event, source_conn, other_conn, is_server): def _handle_ping(self, event, source_conn, other_conn, is_server):
# PING is automatically answered with a PONG by wsproto # Use event.response to create the approprate Pong response
self.connections[other_conn].ping() data = self.connections[other_conn].send(Ping())
other_conn.send(self.connections[other_conn].bytes_to_send()) other_conn.send(data)
source_conn.send(self.connections[source_conn].bytes_to_send()) data = self.connections[source_conn].send(event.response())
source_conn.send(data)
self.log( self.log(
"Ping Received from {}".format("server" if is_server else "client"), "Ping Received from {}".format("server" if is_server else "client"),
"info", "info",
@ -147,7 +150,7 @@ class WebSocketLayer(base.Layer):
) )
return True return True
def _handle_pong_received(self, event, source_conn, other_conn, is_server): def _handle_pong(self, event, source_conn, other_conn, is_server):
self.log( self.log(
"Pong Received from {}".format("server" if is_server else "client"), "Pong Received from {}".format("server" if is_server else "client"),
"info", "info",
@ -155,14 +158,15 @@ class WebSocketLayer(base.Layer):
) )
return True return True
def _handle_connection_closed(self, event, source_conn, other_conn, is_server): def _handle_close_connection(self, event, source_conn, other_conn, is_server):
self.flow.close_sender = "server" if is_server else "client" self.flow.close_sender = "server" if is_server else "client"
self.flow.close_code = event.code self.flow.close_code = event.code
self.flow.close_reason = event.reason self.flow.close_reason = event.reason
self.connections[other_conn].close(event.code, event.reason) data = self.connections[other_conn].send(CloseConnection(code=event.code, reason=event.reason))
other_conn.send(self.connections[other_conn].bytes_to_send()) other_conn.send(data)
source_conn.send(self.connections[source_conn].bytes_to_send()) data = self.connections[source_conn].send(event.response())
source_conn.send(data)
return False return False
@ -170,8 +174,7 @@ class WebSocketLayer(base.Layer):
while True: while True:
try: try:
payload = message_queue.get_nowait() payload = message_queue.get_nowait()
self.connections[endpoint].send_data(payload, final=True) data = self.connections[endpoint].send(Message(data=payload, message_finished=True))
data = self.connections[endpoint].bytes_to_send()
endpoint.send(data) endpoint.send(data)
except queue.Empty: except queue.Empty:
break break
@ -197,8 +200,8 @@ class WebSocketLayer(base.Layer):
is_server = (source_conn == self.server_conn) is_server = (source_conn == self.server_conn)
frame = websockets.Frame.from_file(source_conn.rfile) frame = websockets.Frame.from_file(source_conn.rfile)
self.connections[source_conn].receive_bytes(bytes(frame)) data = self.connections[source_conn].receive_data(bytes(frame))
source_conn.send(self.connections[source_conn].bytes_to_send()) source_conn.send(data)
if close_received: if close_received:
return return

View File

@ -80,7 +80,7 @@ setup(
"sortedcontainers>=1.5.4,<2.1", "sortedcontainers>=1.5.4,<2.1",
"tornado>=4.3,<5.2", "tornado>=4.3,<5.2",
"urwid>=2.0.1,<2.1", "urwid>=2.0.1,<2.1",
"wsproto>=0.12.0,<0.13.0", "wsproto>=0.13.0,<0.14.0",
], ],
extras_require={ extras_require={
':sys_platform == "win32"': [ ':sys_platform == "win32"': [