websockets: update protocol detection

This commit is contained in:
Thomas Kriechbaumer 2016-08-21 11:12:41 +02:00
parent e5b0dae7e9
commit cd3d30633f
6 changed files with 39 additions and 59 deletions

View File

@ -28,6 +28,8 @@ Events = frozenset([
"response",
"responseheaders",
"websockets_handshake",
"next_layer",
"error",

View File

@ -334,6 +334,10 @@ class FlowMaster(controller.Master):
self.client_playback.clear(f)
return f
@controller.handler
def websockets_handshake(self, f):
return f
def handle_intercept(self, f):
self.state.update_flow(f)

View File

@ -29,8 +29,10 @@ from __future__ import absolute_import, print_function, division
from .base import Layer, ServerConnectionMixin
from .http import UpstreamConnectLayer
from .http import HttpLayer
from .http1 import Http1Layer
from .http2 import Http2Layer
from .websockets import WebSocketsLayer
from .rawtcp import RawTCPLayer
from .tls import TlsClientHello
from .tls import TlsLayer
@ -40,7 +42,9 @@ __all__ = [
"Layer", "ServerConnectionMixin",
"TlsLayer", "is_tls_record_magic", "TlsClientHello",
"UpstreamConnectLayer",
"HttpLayer",
"Http1Layer",
"Http2Layer",
"WebSocketsLayer",
"RawTCPLayer",
]

View File

@ -10,7 +10,6 @@ import six
from mitmproxy import exceptions
from mitmproxy import models
from mitmproxy.protocol import base
from .websockets import WebSocketsLayer
import netlib.exceptions
from netlib import http
@ -192,20 +191,10 @@ class HttpLayer(base.Layer):
self.process_request_hook(flow)
try:
# WebSockets
if websockets.check_handshake(request.headers):
if websockets.check_client_version(request.headers):
layer = WebSocketsLayer(self, request)
layer()
return
else:
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
# we only support RFC6455 with WebSockets version 13
self.send_response(models.make_error_response(
400,
http.status_codes.RESPONSES.get(400),
http.Headers(sec_websocket_version="13")
))
return
# allow inline scripts to manupulate the client handshake
self.channel.ask("websockets_handshake", flow)
if not flow.response:
self.establish_server_connection(
@ -230,7 +219,7 @@ class HttpLayer(base.Layer):
# It may be useful to pass additional args (such as the upgrade header)
# to next_layer in the future
if flow.response.status_code == 101:
layer = self.ctx.next_layer(self)
layer = self.ctx.next_layer(self, flow)
layer()
return

View File

@ -6,12 +6,10 @@ import struct
from OpenSSL import SSL
from mitmproxy import exceptions
from mitmproxy import models
from mitmproxy.protocol import base
import netlib.exceptions
from netlib import tcp
from netlib import http
from netlib import websockets
@ -38,44 +36,17 @@ class WebSocketsLayer(base.Layer):
Only raw frames are forwarded to the other endpoint.
"""
def __init__(self, ctx, request):
def __init__(self, ctx, flow):
super(WebSocketsLayer, self).__init__(ctx)
self._request = request
self._flow = flow
self.client_key = websockets.get_client_key(self._request.headers)
self.client_protocol = websockets.get_protocol(self._request.headers)
self.client_extensions = websockets.get_extensions(self._request.headers)
self.client_key = websockets.get_client_key(self._flow.request.headers)
self.client_protocol = websockets.get_protocol(self._flow.request.headers)
self.client_extensions = websockets.get_extensions(self._flow.request.headers)
self.server_accept = None
self.server_protocol = None
self.server_extensions = None
def _initiate_server_conn(self):
self.establish_server_connection(
self._request.host,
self._request.port,
self._request.scheme,
)
self.server_conn.send(netlib.http.http1.assemble_request(self._request))
response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None)
if not websockets.check_handshake(response.headers):
raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers))
self.server_accept = websockets.get_server_accept(response.headers)
self.server_protocol = websockets.get_protocol(response.headers)
self.server_extensions = websockets.get_extensions(response.headers)
def _complete_handshake(self):
headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions)
self.send_response(models.HTTPResponse(
self._request.http_version,
101,
http.status_codes.RESPONSES.get(101),
headers,
b"",
))
self.server_accept = websockets.get_server_accept(self._flow.response.headers)
self.server_protocol = websockets.get_protocol(self._flow.response.headers)
self.server_extensions = websockets.get_extensions(self._flow.response.headers)
def _handle_frame(self, frame, source_conn, other_conn, is_server):
self.log(
@ -114,9 +85,6 @@ class WebSocketsLayer(base.Layer):
return True
def __call__(self):
self._initiate_server_conn()
self._complete_handshake()
client = self.client_conn.connection
server = self.server_conn.connection
conns = [client, server]

View File

@ -4,6 +4,7 @@ import sys
import six
from netlib import websockets
import netlib.exceptions
from mitmproxy import exceptions
from mitmproxy import protocol
@ -32,7 +33,7 @@ class RootContext(object):
self.channel = channel
self.config = config
def next_layer(self, top_layer):
def next_layer(self, top_layer, flow=None):
"""
This function determines the next layer in the protocol stack.
@ -42,10 +43,22 @@ class RootContext(object):
Returns:
The next layer
"""
layer = self._next_layer(top_layer)
layer = self._next_layer(top_layer, flow)
return self.channel.ask("next_layer", layer)
def _next_layer(self, top_layer):
def _next_layer(self, top_layer, flow):
if flow is not None:
# We already have a flow, try to derive the next information from it
# Check for WebSockets handshake
is_websockets = (
flow and
websockets.check_handshake(flow.request.headers) and
websockets.check_handshake(flow.response.headers)
)
if isinstance(top_layer, protocol.HttpLayer) and is_websockets:
return protocol.WebSocketsLayer(top_layer, flow)
try:
d = top_layer.client_conn.rfile.peek(3)
except netlib.exceptions.TcpException as e: