add WebSocket flows and messages

This commit is contained in:
Thomas Kriechbaumer 2016-11-13 17:50:51 +01:00
parent ffb3988dc9
commit 3d8f3d4c23
8 changed files with 238 additions and 54 deletions

View File

@ -223,6 +223,21 @@ class Dumper:
if self.match(f):
self.echo_flow(f)
def websocket_error(self, f):
self.echo(
"Error in WebSocket connection to {}: {}".format(
repr(f.server_conn.address), f.error
),
fg="red"
)
def websocket_message(self, f):
if self.match(f):
message = f.messages[-1]
self.echo(message.info)
if self.flow_detail >= 3:
self._echo_message(message)
def tcp_error(self, f):
self.echo(
"Error in TCP connection to {}: {}".format(
@ -240,4 +255,5 @@ class Dumper:
server=repr(f.server_conn.address),
direction=direction,
))
self._echo_message(message)
if self.flow_detail >= 3:
self._echo_message(message)

View File

@ -1,6 +1,7 @@
from mitmproxy import controller
from mitmproxy import http
from mitmproxy import tcp
from mitmproxy import websocket
Events = frozenset([
"clientconnect",
@ -24,6 +25,10 @@ Events = frozenset([
"resume",
"websocket_handshake",
"websocket_start",
"websocket_message",
"websocket_error",
"websocket_end",
"next_layer",
@ -45,6 +50,17 @@ def event_sequence(f):
yield "response", f
if f.error:
yield "error", f
elif isinstance(f, websocket.WebSocketFlow):
messages = f.messages
f.messages = []
f.reply = controller.DummyReply()
yield "websocket_start", f
while messages:
f.messages.append(messages.pop(0))
yield "websocket_message", f
if f.error:
yield "websocket_error", f
yield "websocket_end", f
elif isinstance(f, tcp.TCPFlow):
messages = f.messages
f.messages = []

View File

@ -4,12 +4,14 @@ from mitmproxy import exceptions
from mitmproxy import flowfilter
from mitmproxy import http
from mitmproxy import tcp
from mitmproxy import websocket
from mitmproxy.contrib import tnetstring
from mitmproxy import io_compat
FLOW_TYPES = dict(
http=http.HTTPFlow,
websocket=websocket.WebSocketFlow,
tcp=tcp.TCPFlow,
)

View File

@ -283,6 +283,22 @@ class Master:
def websocket_handshake(self, f):
pass
@controller.handler
def websocket_start(self, flow):
pass
@controller.handler
def websocket_message(self, flow):
pass
@controller.handler
def websocket_error(self, flow):
pass
@controller.handler
def websocket_end(self, flow):
pass
@controller.handler
def tcp_start(self, flow):
pass

View File

@ -1,18 +1,23 @@
import os
import socket
import struct
from OpenSSL import SSL
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
from mitmproxy.utils import strutils
from mitmproxy.net import tcp
from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketBinaryMessage, WebSocketTextMessage
class WebSocketLayer(base.Layer):
"""
WebSocket layer to intercept, modify, and forward WebSocket connections
WebSocket layer to intercept, modify, and forward WebSocket messages.
Only version 13 is supported (as specified in RFC6455)
Only version 13 is supported (as specified in RFC6455).
Only HTTP/1.1-initiated connections are supported.
The client starts by sending an Upgrade-request.
@ -29,65 +34,106 @@ class WebSocketLayer(base.Layer):
This layer is transparent to any negotiated extensions.
This layer is transparent to any negotiated subprotocols.
Only raw frames are forwarded to the other endpoint.
WebSocket messages are stored in a WebSocketFlow.
"""
def __init__(self, ctx, flow):
def __init__(self, ctx, handshake_flow):
super().__init__(ctx)
self._flow = flow
self.handshake_flow = handshake_flow
self.flow = None # type: WebSocketFlow
self.client_key = websockets.get_client_key(self._flow.request.headers)
self.client_protocol = websockets.get_protocol(self._flow.request.headers)
self.client_extensions = websockets.get_extensions(self._flow.request.headers)
self.server_accept = websockets.get_server_accept(self._flow.response.headers)
self.server_protocol = websockets.get_protocol(self._flow.response.headers)
self.server_extensions = websockets.get_extensions(self._flow.response.headers)
self.client_frame_buffer = []
self.server_frame_buffer = []
def _handle_frame(self, frame, source_conn, other_conn, is_server):
sender = "server" if is_server else "client"
self.log(
"WebSocket frame received from {}".format(sender),
"debug",
[repr(frame)]
)
# sender = "server" if is_server else "client"
# self.log(
# "WebSocket frame received from {}".format(sender),
# "debug",
# [repr(frame)]
# )
if frame.header.opcode & 0x8 == 0:
self.log(
"{direction} websocket {direction} {server}".format(
server=repr(self.server_conn.address),
direction="<-" if is_server else "->",
),
"info",
strutils.bytes_to_escaped_str(frame.payload, keep_spacing=True).splitlines()
)
# forward the data frame to the other side
other_conn.send(bytes(frame))
return self._handle_data_frame(frame, source_conn, other_conn, is_server)
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
# just forward the ping/pong to the other side
other_conn.send(bytes(frame))
return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
elif frame.header.opcode == websockets.OPCODE.CLOSE:
code = '(status code missing)'
msg = None
reason = '(message missing)'
if len(frame.payload) >= 2:
code, = struct.unpack('!H', frame.payload[:2])
msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
if len(frame.payload) > 2:
reason = frame.payload[2:]
self.log("WebSocket connection closed by {}: {} {}, {}".format(sender, code, msg, reason), "info")
other_conn.send(bytes(frame))
# close the connection
return False
return self._handle_close(frame, source_conn, other_conn, is_server)
else:
self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
# unknown frame - just forward it
other_conn.send(bytes(frame))
return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
fb = self.server_frame_buffer if is_server else self.client_frame_buffer
fb.append(frame)
if frame.header.fin:
if frame.header.opcode == websockets.OPCODE.TEXT:
t = WebSocketTextMessage
else:
t = WebSocketBinaryMessage
payload = b''.join(f.payload for f in fb)
fb.clear()
websocket_message = t(self.flow, not is_server, payload)
self.flow.messages.append(websocket_message)
self.channel.ask("websocket_message", self.flow)
# chunk payload into multiple 10kB frames, and send them
payload = websocket_message.content
chunk_size = 10240 # 10kB
chunks = range(0, len(payload), chunk_size)
frms = [
websockets.Frame(
payload=payload[i:i + chunk_size],
opcode=frame.header.opcode,
mask=(False if is_server else 1),
masking_key=(b'' if is_server else os.urandom(4))) for i in chunks
]
frms[-1].header.fin = 1
for frm in frms:
other_conn.send(bytes(frm))
return True
def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
# just forward the ping/pong to the other side
other_conn.send(bytes(frame))
return True
def _handle_close(self, frame, source_conn, other_conn, is_server):
code = '(status code missing)'
msg = None
reason = '(message missing)'
if len(frame.payload) >= 2:
code, = struct.unpack('!H', frame.payload[:2])
msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
if len(frame.payload) > 2:
reason = frame.payload[2:]
other_conn.send(bytes(frame))
sender = "server" if is_server else "client"
self.log("WebSocket connection closed by {}: {} {}, {}".format(sender, code, msg, reason), "info")
# close the connection
return False
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
# unknown frame - just forward it
other_conn.send(bytes(frame))
sender = "server" if is_server else "client"
self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
# continue the connection
return True
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
self.channel.ask("websocket_start", self.flow)
client = self.client_conn.connection
server = self.server_conn.connection
conns = [client, server]
@ -105,7 +151,7 @@ class WebSocketLayer(base.Layer):
if not self._handle_frame(frame, source_conn, other_conn, is_server):
return
except (socket.error, exceptions.TcpException, SSL.Error) as e:
self.log("WebSocket connection closed unexpectedly by {}: {}".format(
"server" if is_server else "client", repr(e)), "info")
except Exception as e: # pragma: no cover
raise exceptions.ProtocolException("Error in WebSocket connection: {}".format(repr(e)))
self.flow.error = flow.Error("WebSocket connection closed unexpectedly: {}".format(repr(e)))
self.channel.tell("websocket_error", self.flow)
finally:
self.channel.tell("websocket_end", self.flow)

View File

@ -11,9 +11,7 @@ class TCPMessage(serializable.Serializable):
def __init__(self, from_client, content, timestamp=None):
self.content = content
self.from_client = from_client
if timestamp is None:
timestamp = time.time()
self.timestamp = timestamp
self.timestamp = timestamp or time.time()
@classmethod
def from_state(cls, state):

View File

@ -446,6 +446,13 @@ class ConsoleMaster(master.Master):
self.logbuffer[:] = []
# Handlers
@controller.handler
def websocket_message(self, f):
super().websocket_message(f)
message = f.messages[-1]
self.add_log(message.info, "info")
self.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
@controller.handler
def tcp_message(self, f):
super().tcp_message(f)

83
mitmproxy/websocket.py Normal file
View File

@ -0,0 +1,83 @@
import time
from typing import List
from mitmproxy import flow
from mitmproxy.http import HTTPFlow
from mitmproxy.net import websockets
from mitmproxy.utils import strutils
from mitmproxy.types import serializable
class WebSocketMessage(serializable.Serializable):
def __init__(self, flow, from_client, content, timestamp=None):
self.flow = flow
self.content = content
self.from_client = from_client
self.timestamp = timestamp or time.time()
@classmethod
def from_state(cls, state):
return cls(*state)
def get_state(self):
return self.from_client, self.content, self.timestamp
def set_state(self, state):
self.from_client = state.pop("from_client")
self.content = state.pop("content")
self.timestamp = state.pop("timestamp")
@property
def info(self):
return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
type=self.type,
client=repr(self.flow.client_conn.address),
server=repr(self.flow.server_conn.address),
direction="->" if self.from_client else "<-",
endpoint=self.flow.handshake_flow.request.path,
)
class WebSocketBinaryMessage(WebSocketMessage):
type = 'binary'
def __repr__(self):
return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content))
class WebSocketTextMessage(WebSocketMessage):
type = 'text'
def __repr__(self):
return "text message: {}".format(repr(self.content))
class WebSocketFlow(flow.Flow):
"""
A WebsocketFlow is a simplified representation of a Websocket session.
"""
def __init__(self, client_conn, server_conn, handshake_flow, live=None):
super().__init__("websocket", client_conn, server_conn, live)
self.messages = [] # type: List[WebSocketMessage]
self.handshake_flow = handshake_flow
self.client_key = websockets.get_client_key(self.handshake_flow.request.headers)
self.client_protocol = websockets.get_protocol(self.handshake_flow.request.headers)
self.client_extensions = websockets.get_extensions(self.handshake_flow.request.headers)
self.server_accept = websockets.get_server_accept(self.handshake_flow.response.headers)
self.server_protocol = websockets.get_protocol(self.handshake_flow.response.headers)
self.server_extensions = websockets.get_extensions(self.handshake_flow.response.headers)
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
_stateobject_attributes.update(
messages=List[WebSocketMessage],
handshake_flow=HTTPFlow,
)
def __repr__(self):
return "<WebSocketFlow ({} messages)>".format(len(self.messages))