websockets: update protocol detection

This commit is contained in:
Thomas Kriechbaumer 2016-08-21 11:12:41 +02:00
parent e5b0dae7e9
commit cd3d30633f
6 changed files with 39 additions and 59 deletions

View File

@ -28,6 +28,8 @@ Events = frozenset([
"response", "response",
"responseheaders", "responseheaders",
"websockets_handshake",
"next_layer", "next_layer",
"error", "error",

View File

@ -334,6 +334,10 @@ class FlowMaster(controller.Master):
self.client_playback.clear(f) self.client_playback.clear(f)
return f return f
@controller.handler
def websockets_handshake(self, f):
return f
def handle_intercept(self, f): def handle_intercept(self, f):
self.state.update_flow(f) self.state.update_flow(f)

View File

@ -29,8 +29,10 @@ from __future__ import absolute_import, print_function, division
from .base import Layer, ServerConnectionMixin from .base import Layer, ServerConnectionMixin
from .http import UpstreamConnectLayer from .http import UpstreamConnectLayer
from .http import HttpLayer
from .http1 import Http1Layer from .http1 import Http1Layer
from .http2 import Http2Layer from .http2 import Http2Layer
from .websockets import WebSocketsLayer
from .rawtcp import RawTCPLayer from .rawtcp import RawTCPLayer
from .tls import TlsClientHello from .tls import TlsClientHello
from .tls import TlsLayer from .tls import TlsLayer
@ -40,7 +42,9 @@ __all__ = [
"Layer", "ServerConnectionMixin", "Layer", "ServerConnectionMixin",
"TlsLayer", "is_tls_record_magic", "TlsClientHello", "TlsLayer", "is_tls_record_magic", "TlsClientHello",
"UpstreamConnectLayer", "UpstreamConnectLayer",
"HttpLayer",
"Http1Layer", "Http1Layer",
"Http2Layer", "Http2Layer",
"WebSocketsLayer",
"RawTCPLayer", "RawTCPLayer",
] ]

View File

@ -10,7 +10,6 @@ import six
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import models from mitmproxy import models
from mitmproxy.protocol import base from mitmproxy.protocol import base
from .websockets import WebSocketsLayer
import netlib.exceptions import netlib.exceptions
from netlib import http from netlib import http
@ -192,20 +191,10 @@ class HttpLayer(base.Layer):
self.process_request_hook(flow) self.process_request_hook(flow)
try: try:
# WebSockets if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
if websockets.check_handshake(request.headers):
if websockets.check_client_version(request.headers):
layer = WebSocketsLayer(self, request)
layer()
return
else:
# we only support RFC6455 with WebSockets version 13 # we only support RFC6455 with WebSockets version 13
self.send_response(models.make_error_response( # allow inline scripts to manupulate the client handshake
400, self.channel.ask("websockets_handshake", flow)
http.status_codes.RESPONSES.get(400),
http.Headers(sec_websocket_version="13")
))
return
if not flow.response: if not flow.response:
self.establish_server_connection( self.establish_server_connection(
@ -230,7 +219,7 @@ class HttpLayer(base.Layer):
# It may be useful to pass additional args (such as the upgrade header) # It may be useful to pass additional args (such as the upgrade header)
# to next_layer in the future # to next_layer in the future
if flow.response.status_code == 101: if flow.response.status_code == 101:
layer = self.ctx.next_layer(self) layer = self.ctx.next_layer(self, flow)
layer() layer()
return return

View File

@ -6,12 +6,10 @@ import struct
from OpenSSL import SSL from OpenSSL import SSL
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import models
from mitmproxy.protocol import base from mitmproxy.protocol import base
import netlib.exceptions import netlib.exceptions
from netlib import tcp from netlib import tcp
from netlib import http
from netlib import websockets from netlib import websockets
@ -38,44 +36,17 @@ class WebSocketsLayer(base.Layer):
Only raw frames are forwarded to the other endpoint. Only raw frames are forwarded to the other endpoint.
""" """
def __init__(self, ctx, request): def __init__(self, ctx, flow):
super(WebSocketsLayer, self).__init__(ctx) super(WebSocketsLayer, self).__init__(ctx)
self._request = request self._flow = flow
self.client_key = websockets.get_client_key(self._request.headers) self.client_key = websockets.get_client_key(self._flow.request.headers)
self.client_protocol = websockets.get_protocol(self._request.headers) self.client_protocol = websockets.get_protocol(self._flow.request.headers)
self.client_extensions = websockets.get_extensions(self._request.headers) self.client_extensions = websockets.get_extensions(self._flow.request.headers)
self.server_accept = None self.server_accept = websockets.get_server_accept(self._flow.response.headers)
self.server_protocol = None self.server_protocol = websockets.get_protocol(self._flow.response.headers)
self.server_extensions = None self.server_extensions = websockets.get_extensions(self._flow.response.headers)
def _initiate_server_conn(self):
self.establish_server_connection(
self._request.host,
self._request.port,
self._request.scheme,
)
self.server_conn.send(netlib.http.http1.assemble_request(self._request))
response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None)
if not websockets.check_handshake(response.headers):
raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers))
self.server_accept = websockets.get_server_accept(response.headers)
self.server_protocol = websockets.get_protocol(response.headers)
self.server_extensions = websockets.get_extensions(response.headers)
def _complete_handshake(self):
headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions)
self.send_response(models.HTTPResponse(
self._request.http_version,
101,
http.status_codes.RESPONSES.get(101),
headers,
b"",
))
def _handle_frame(self, frame, source_conn, other_conn, is_server): def _handle_frame(self, frame, source_conn, other_conn, is_server):
self.log( self.log(
@ -114,9 +85,6 @@ class WebSocketsLayer(base.Layer):
return True return True
def __call__(self): def __call__(self):
self._initiate_server_conn()
self._complete_handshake()
client = self.client_conn.connection client = self.client_conn.connection
server = self.server_conn.connection server = self.server_conn.connection
conns = [client, server] conns = [client, server]

View File

@ -4,6 +4,7 @@ import sys
import six import six
from netlib import websockets
import netlib.exceptions import netlib.exceptions
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import protocol from mitmproxy import protocol
@ -32,7 +33,7 @@ class RootContext(object):
self.channel = channel self.channel = channel
self.config = config self.config = config
def next_layer(self, top_layer): def next_layer(self, top_layer, flow=None):
""" """
This function determines the next layer in the protocol stack. This function determines the next layer in the protocol stack.
@ -42,10 +43,22 @@ class RootContext(object):
Returns: Returns:
The next layer The next layer
""" """
layer = self._next_layer(top_layer) layer = self._next_layer(top_layer, flow)
return self.channel.ask("next_layer", layer) return self.channel.ask("next_layer", layer)
def _next_layer(self, top_layer): def _next_layer(self, top_layer, flow):
if flow is not None:
# We already have a flow, try to derive the next information from it
# Check for WebSockets handshake
is_websockets = (
flow and
websockets.check_handshake(flow.request.headers) and
websockets.check_handshake(flow.response.headers)
)
if isinstance(top_layer, protocol.HttpLayer) and is_websockets:
return protocol.WebSocketsLayer(top_layer, flow)
try: try:
d = top_layer.client_conn.rfile.peek(3) d = top_layer.client_conn.rfile.peek(3)
except netlib.exceptions.TcpException as e: except netlib.exceptions.TcpException as e: