From cd3d30633fae965044d5f320b5544dfbd039693f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 21 Aug 2016 11:12:41 +0200 Subject: [PATCH] websockets: update protocol detection --- mitmproxy/controller.py | 2 ++ mitmproxy/flow/master.py | 4 +++ mitmproxy/protocol/__init__.py | 4 +++ mitmproxy/protocol/http.py | 21 ++++---------- mitmproxy/protocol/websockets.py | 48 ++++++-------------------------- mitmproxy/proxy/root_context.py | 19 +++++++++++-- 6 files changed, 39 insertions(+), 59 deletions(-) diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index c262b192d..d886af970 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -28,6 +28,8 @@ Events = frozenset([ "response", "responseheaders", + "websockets_handshake", + "next_layer", "error", diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index 0475ef4ec..9cdcc8dd5 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -334,6 +334,10 @@ class FlowMaster(controller.Master): self.client_playback.clear(f) return f + @controller.handler + def websockets_handshake(self, f): + return f + def handle_intercept(self, f): self.state.update_flow(f) diff --git a/mitmproxy/protocol/__init__.py b/mitmproxy/protocol/__init__.py index 510cd1955..b99b55bdd 100644 --- a/mitmproxy/protocol/__init__.py +++ b/mitmproxy/protocol/__init__.py @@ -29,8 +29,10 @@ from __future__ import absolute_import, print_function, division from .base import Layer, ServerConnectionMixin from .http import UpstreamConnectLayer +from .http import HttpLayer from .http1 import Http1Layer from .http2 import Http2Layer +from .websockets import WebSocketsLayer from .rawtcp import RawTCPLayer from .tls import TlsClientHello from .tls import TlsLayer @@ -40,7 +42,9 @@ __all__ = [ "Layer", "ServerConnectionMixin", "TlsLayer", "is_tls_record_magic", "TlsClientHello", "UpstreamConnectLayer", + "HttpLayer", "Http1Layer", "Http2Layer", + "WebSocketsLayer", "RawTCPLayer", ] diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index fbb52c92b..1418d6e9e 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -10,7 +10,6 @@ import six from mitmproxy import exceptions from mitmproxy import models from mitmproxy.protocol import base -from .websockets import WebSocketsLayer import netlib.exceptions from netlib import http @@ -192,20 +191,10 @@ class HttpLayer(base.Layer): self.process_request_hook(flow) try: - # WebSockets - if websockets.check_handshake(request.headers): - if websockets.check_client_version(request.headers): - layer = WebSocketsLayer(self, request) - layer() - return - else: - # we only support RFC6455 with WebSockets version 13 - self.send_response(models.make_error_response( - 400, - http.status_codes.RESPONSES.get(400), - http.Headers(sec_websocket_version="13") - )) - return + if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers): + # we only support RFC6455 with WebSockets version 13 + # allow inline scripts to manupulate the client handshake + self.channel.ask("websockets_handshake", flow) if not flow.response: self.establish_server_connection( @@ -230,7 +219,7 @@ class HttpLayer(base.Layer): # It may be useful to pass additional args (such as the upgrade header) # to next_layer in the future if flow.response.status_code == 101: - layer = self.ctx.next_layer(self) + layer = self.ctx.next_layer(self, flow) layer() return diff --git a/mitmproxy/protocol/websockets.py b/mitmproxy/protocol/websockets.py index 05eaa5372..f15a38ef1 100644 --- a/mitmproxy/protocol/websockets.py +++ b/mitmproxy/protocol/websockets.py @@ -6,12 +6,10 @@ import struct from OpenSSL import SSL from mitmproxy import exceptions -from mitmproxy import models from mitmproxy.protocol import base import netlib.exceptions from netlib import tcp -from netlib import http from netlib import websockets @@ -38,44 +36,17 @@ class WebSocketsLayer(base.Layer): Only raw frames are forwarded to the other endpoint. """ - def __init__(self, ctx, request): + def __init__(self, ctx, flow): super(WebSocketsLayer, self).__init__(ctx) - self._request = request + self._flow = flow - self.client_key = websockets.get_client_key(self._request.headers) - self.client_protocol = websockets.get_protocol(self._request.headers) - self.client_extensions = websockets.get_extensions(self._request.headers) + self.client_key = websockets.get_client_key(self._flow.request.headers) + self.client_protocol = websockets.get_protocol(self._flow.request.headers) + self.client_extensions = websockets.get_extensions(self._flow.request.headers) - self.server_accept = None - self.server_protocol = None - self.server_extensions = None - - def _initiate_server_conn(self): - self.establish_server_connection( - self._request.host, - self._request.port, - self._request.scheme, - ) - - self.server_conn.send(netlib.http.http1.assemble_request(self._request)) - response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None) - - if not websockets.check_handshake(response.headers): - raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers)) - - self.server_accept = websockets.get_server_accept(response.headers) - self.server_protocol = websockets.get_protocol(response.headers) - self.server_extensions = websockets.get_extensions(response.headers) - - def _complete_handshake(self): - headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions) - self.send_response(models.HTTPResponse( - self._request.http_version, - 101, - http.status_codes.RESPONSES.get(101), - headers, - b"", - )) + self.server_accept = websockets.get_server_accept(self._flow.response.headers) + self.server_protocol = websockets.get_protocol(self._flow.response.headers) + self.server_extensions = websockets.get_extensions(self._flow.response.headers) def _handle_frame(self, frame, source_conn, other_conn, is_server): self.log( @@ -114,9 +85,6 @@ class WebSocketsLayer(base.Layer): return True def __call__(self): - self._initiate_server_conn() - self._complete_handshake() - client = self.client_conn.connection server = self.server_conn.connection conns = [client, server] diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 81dd625c4..956113625 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -4,6 +4,7 @@ import sys import six +from netlib import websockets import netlib.exceptions from mitmproxy import exceptions from mitmproxy import protocol @@ -32,7 +33,7 @@ class RootContext(object): self.channel = channel self.config = config - def next_layer(self, top_layer): + def next_layer(self, top_layer, flow=None): """ This function determines the next layer in the protocol stack. @@ -42,10 +43,22 @@ class RootContext(object): Returns: The next layer """ - layer = self._next_layer(top_layer) + layer = self._next_layer(top_layer, flow) return self.channel.ask("next_layer", layer) - def _next_layer(self, top_layer): + def _next_layer(self, top_layer, flow): + if flow is not None: + # We already have a flow, try to derive the next information from it + + # Check for WebSockets handshake + is_websockets = ( + flow and + websockets.check_handshake(flow.request.headers) and + websockets.check_handshake(flow.response.headers) + ) + if isinstance(top_layer, protocol.HttpLayer) and is_websockets: + return protocol.WebSocketsLayer(top_layer, flow) + try: d = top_layer.client_conn.rfile.peek(3) except netlib.exceptions.TcpException as e: