Merge pull request #3119 from Kriechi/inject-websocket-message

websocket: inject messages via flow
This commit is contained in:
Aldo Cortesi 2018-05-17 08:35:40 +12:00 committed by GitHub
commit 48d7a944bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 122 additions and 1 deletions

View File

@ -0,0 +1,23 @@
"""
This example shows how to inject a WebSocket message to the client.
Every new WebSocket connection will trigger a new asyncio task that
periodically injects a new message to the client.
"""
import asyncio
import mitmproxy.websocket
class InjectWebSocketMessage:
async def inject(self, flow: mitmproxy.websocket.WebSocketFlow):
i = 0
while not flow.ended and not flow.error:
await asyncio.sleep(5)
flow.inject_message(flow.client_conn, 'This is the #{} an injected message!'.format(i))
i += 1
def websocket_start(self, flow):
asyncio.get_event_loop().create_task(self.inject(flow))
addons = [InjectWebSocketMessage()]

View File

@ -1,3 +1,4 @@
import queue
import socket import socket
from OpenSSL import SSL from OpenSSL import SSL
@ -165,8 +166,18 @@ class WebSocketLayer(base.Layer):
return False return False
def _inject_messages(self, endpoint, message_queue):
while True:
try:
payload = message_queue.get_nowait()
self.connections[endpoint].send_data(payload, final=True)
data = self.connections[endpoint].bytes_to_send()
endpoint.send(data)
except queue.Empty:
break
def __call__(self): def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self) self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow)
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
self.handshake_flow.metadata['websocket_flow'] = self.flow.id self.handshake_flow.metadata['websocket_flow'] = self.flow.id
self.channel.ask("websocket_start", self.flow) self.channel.ask("websocket_start", self.flow)
@ -176,6 +187,9 @@ class WebSocketLayer(base.Layer):
try: try:
while not self.channel.should_exit.is_set(): while not self.channel.should_exit.is_set():
self._inject_messages(self.client_conn, self.flow._inject_messages_client)
self._inject_messages(self.server_conn, self.flow._inject_messages_server)
r = tcp.ssl_read_select(conns, 0.1) r = tcp.ssl_read_select(conns, 0.1)
for conn in r: for conn in r:
source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
@ -198,4 +212,5 @@ class WebSocketLayer(base.Layer):
self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e))) self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e)))
self.channel.tell("websocket_error", self.flow) self.channel.tell("websocket_error", self.flow)
finally: finally:
self.flow.ended = True
self.channel.tell("websocket_end", self.flow) self.channel.tell("websocket_end", self.flow)

View File

@ -1,4 +1,5 @@
import time import time
import queue
from typing import List, Optional from typing import List, Optional
from wsproto.frame_protocol import CloseReason from wsproto.frame_protocol import CloseReason
@ -77,6 +78,11 @@ class WebSocketFlow(flow.Flow):
"""True of this connection is streaming directly to the other endpoint.""" """True of this connection is streaming directly to the other endpoint."""
self.handshake_flow = handshake_flow self.handshake_flow = handshake_flow
"""The HTTP flow containing the initial WebSocket handshake.""" """The HTTP flow containing the initial WebSocket handshake."""
self.ended = False
"""True when the WebSocket connection has been closed."""
self._inject_messages_client = queue.Queue(maxsize=1)
self._inject_messages_server = queue.Queue(maxsize=1)
if handshake_flow: if handshake_flow:
self.client_key = websockets.get_client_key(handshake_flow.request.headers) self.client_key = websockets.get_client_key(handshake_flow.request.headers)
@ -134,3 +140,25 @@ class WebSocketFlow(flow.Flow):
direction="->" if message.from_client else "<-", direction="->" if message.from_client else "<-",
endpoint=self.handshake_flow.request.path, endpoint=self.handshake_flow.request.path,
) )
def inject_message(self, endpoint, payload):
"""
Inject and send a full WebSocket message to the remote endpoint.
This might corrupt your WebSocket connection! Be careful!
The endpoint needs to be either flow.client_conn or flow.server_conn.
If ``payload`` is of type ``bytes`` then the message is flagged as
being binary If it is of type ``str`` encoded as UTF-8 and sent as
text.
:param payload: The message body to send.
:type payload: ``bytes`` or ``str``
"""
if endpoint == self.client_conn:
self._inject_messages_client.put(payload)
elif endpoint == self.server_conn:
self._inject_messages_server.put(payload)
else:
raise ValueError('Invalid endpoint')

View File

@ -467,3 +467,46 @@ class TestExtension(_WebSocketTest):
assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY
class TestInjectMessageClient(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
pass
def test_inject_message_client(self):
class Inject:
def websocket_start(self, flow):
flow.inject_message(flow.client_conn, 'This is an injected message!')
self.proxy.set_addons(Inject())
self.setup_connection()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == websockets.OPCODE.TEXT
assert frame.payload == b'This is an injected message!'
class TestInjectMessageServer(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
assert frame.header.opcode == websockets.OPCODE.TEXT
success = frame.payload == b'This is an injected message!'
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=str(success).encode())))
wfile.flush()
def test_inject_message_server(self):
class Inject:
def websocket_start(self, flow):
flow.inject_message(flow.server_conn, 'This is an injected message!')
self.proxy.set_addons(Inject())
self.setup_connection()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == websockets.OPCODE.TEXT
assert frame.payload == b'True'

View File

@ -92,3 +92,15 @@ class TestWebSocketFlow:
assert not f.messages[-1].killed assert not f.messages[-1].killed
f.messages[-1].kill() f.messages[-1].kill()
assert f.messages[-1].killed assert f.messages[-1].killed
def test_inject_message(self):
f = tflow.twebsocketflow()
with pytest.raises(ValueError):
f.inject_message(None, 'foobar')
f.inject_message(f.client_conn, 'foobar')
assert f._inject_messages_client.qsize() == 1
f.inject_message(f.server_conn, 'foobar')
assert f._inject_messages_client.qsize() == 1