mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-29 11:03:13 +00:00
websockets: update protocol detection
This commit is contained in:
parent
e5b0dae7e9
commit
cd3d30633f
@ -28,6 +28,8 @@ Events = frozenset([
|
||||
"response",
|
||||
"responseheaders",
|
||||
|
||||
"websockets_handshake",
|
||||
|
||||
"next_layer",
|
||||
|
||||
"error",
|
||||
|
@ -334,6 +334,10 @@ class FlowMaster(controller.Master):
|
||||
self.client_playback.clear(f)
|
||||
return f
|
||||
|
||||
@controller.handler
|
||||
def websockets_handshake(self, f):
|
||||
return f
|
||||
|
||||
def handle_intercept(self, f):
|
||||
self.state.update_flow(f)
|
||||
|
||||
|
@ -29,8 +29,10 @@ from __future__ import absolute_import, print_function, division
|
||||
|
||||
from .base import Layer, ServerConnectionMixin
|
||||
from .http import UpstreamConnectLayer
|
||||
from .http import HttpLayer
|
||||
from .http1 import Http1Layer
|
||||
from .http2 import Http2Layer
|
||||
from .websockets import WebSocketsLayer
|
||||
from .rawtcp import RawTCPLayer
|
||||
from .tls import TlsClientHello
|
||||
from .tls import TlsLayer
|
||||
@ -40,7 +42,9 @@ __all__ = [
|
||||
"Layer", "ServerConnectionMixin",
|
||||
"TlsLayer", "is_tls_record_magic", "TlsClientHello",
|
||||
"UpstreamConnectLayer",
|
||||
"HttpLayer",
|
||||
"Http1Layer",
|
||||
"Http2Layer",
|
||||
"WebSocketsLayer",
|
||||
"RawTCPLayer",
|
||||
]
|
||||
|
@ -10,7 +10,6 @@ import six
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import models
|
||||
from mitmproxy.protocol import base
|
||||
from .websockets import WebSocketsLayer
|
||||
|
||||
import netlib.exceptions
|
||||
from netlib import http
|
||||
@ -192,20 +191,10 @@ class HttpLayer(base.Layer):
|
||||
self.process_request_hook(flow)
|
||||
|
||||
try:
|
||||
# WebSockets
|
||||
if websockets.check_handshake(request.headers):
|
||||
if websockets.check_client_version(request.headers):
|
||||
layer = WebSocketsLayer(self, request)
|
||||
layer()
|
||||
return
|
||||
else:
|
||||
# we only support RFC6455 with WebSockets version 13
|
||||
self.send_response(models.make_error_response(
|
||||
400,
|
||||
http.status_codes.RESPONSES.get(400),
|
||||
http.Headers(sec_websocket_version="13")
|
||||
))
|
||||
return
|
||||
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
|
||||
# we only support RFC6455 with WebSockets version 13
|
||||
# allow inline scripts to manupulate the client handshake
|
||||
self.channel.ask("websockets_handshake", flow)
|
||||
|
||||
if not flow.response:
|
||||
self.establish_server_connection(
|
||||
@ -230,7 +219,7 @@ class HttpLayer(base.Layer):
|
||||
# It may be useful to pass additional args (such as the upgrade header)
|
||||
# to next_layer in the future
|
||||
if flow.response.status_code == 101:
|
||||
layer = self.ctx.next_layer(self)
|
||||
layer = self.ctx.next_layer(self, flow)
|
||||
layer()
|
||||
return
|
||||
|
||||
|
@ -6,12 +6,10 @@ import struct
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import models
|
||||
from mitmproxy.protocol import base
|
||||
|
||||
import netlib.exceptions
|
||||
from netlib import tcp
|
||||
from netlib import http
|
||||
from netlib import websockets
|
||||
|
||||
|
||||
@ -38,44 +36,17 @@ class WebSocketsLayer(base.Layer):
|
||||
Only raw frames are forwarded to the other endpoint.
|
||||
"""
|
||||
|
||||
def __init__(self, ctx, request):
|
||||
def __init__(self, ctx, flow):
|
||||
super(WebSocketsLayer, self).__init__(ctx)
|
||||
self._request = request
|
||||
self._flow = flow
|
||||
|
||||
self.client_key = websockets.get_client_key(self._request.headers)
|
||||
self.client_protocol = websockets.get_protocol(self._request.headers)
|
||||
self.client_extensions = websockets.get_extensions(self._request.headers)
|
||||
self.client_key = websockets.get_client_key(self._flow.request.headers)
|
||||
self.client_protocol = websockets.get_protocol(self._flow.request.headers)
|
||||
self.client_extensions = websockets.get_extensions(self._flow.request.headers)
|
||||
|
||||
self.server_accept = None
|
||||
self.server_protocol = None
|
||||
self.server_extensions = None
|
||||
|
||||
def _initiate_server_conn(self):
|
||||
self.establish_server_connection(
|
||||
self._request.host,
|
||||
self._request.port,
|
||||
self._request.scheme,
|
||||
)
|
||||
|
||||
self.server_conn.send(netlib.http.http1.assemble_request(self._request))
|
||||
response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None)
|
||||
|
||||
if not websockets.check_handshake(response.headers):
|
||||
raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers))
|
||||
|
||||
self.server_accept = websockets.get_server_accept(response.headers)
|
||||
self.server_protocol = websockets.get_protocol(response.headers)
|
||||
self.server_extensions = websockets.get_extensions(response.headers)
|
||||
|
||||
def _complete_handshake(self):
|
||||
headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions)
|
||||
self.send_response(models.HTTPResponse(
|
||||
self._request.http_version,
|
||||
101,
|
||||
http.status_codes.RESPONSES.get(101),
|
||||
headers,
|
||||
b"",
|
||||
))
|
||||
self.server_accept = websockets.get_server_accept(self._flow.response.headers)
|
||||
self.server_protocol = websockets.get_protocol(self._flow.response.headers)
|
||||
self.server_extensions = websockets.get_extensions(self._flow.response.headers)
|
||||
|
||||
def _handle_frame(self, frame, source_conn, other_conn, is_server):
|
||||
self.log(
|
||||
@ -114,9 +85,6 @@ class WebSocketsLayer(base.Layer):
|
||||
return True
|
||||
|
||||
def __call__(self):
|
||||
self._initiate_server_conn()
|
||||
self._complete_handshake()
|
||||
|
||||
client = self.client_conn.connection
|
||||
server = self.server_conn.connection
|
||||
conns = [client, server]
|
||||
|
@ -4,6 +4,7 @@ import sys
|
||||
|
||||
import six
|
||||
|
||||
from netlib import websockets
|
||||
import netlib.exceptions
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import protocol
|
||||
@ -32,7 +33,7 @@ class RootContext(object):
|
||||
self.channel = channel
|
||||
self.config = config
|
||||
|
||||
def next_layer(self, top_layer):
|
||||
def next_layer(self, top_layer, flow=None):
|
||||
"""
|
||||
This function determines the next layer in the protocol stack.
|
||||
|
||||
@ -42,10 +43,22 @@ class RootContext(object):
|
||||
Returns:
|
||||
The next layer
|
||||
"""
|
||||
layer = self._next_layer(top_layer)
|
||||
layer = self._next_layer(top_layer, flow)
|
||||
return self.channel.ask("next_layer", layer)
|
||||
|
||||
def _next_layer(self, top_layer):
|
||||
def _next_layer(self, top_layer, flow):
|
||||
if flow is not None:
|
||||
# We already have a flow, try to derive the next information from it
|
||||
|
||||
# Check for WebSockets handshake
|
||||
is_websockets = (
|
||||
flow and
|
||||
websockets.check_handshake(flow.request.headers) and
|
||||
websockets.check_handshake(flow.response.headers)
|
||||
)
|
||||
if isinstance(top_layer, protocol.HttpLayer) and is_websockets:
|
||||
return protocol.WebSocketsLayer(top_layer, flow)
|
||||
|
||||
try:
|
||||
d = top_layer.client_conn.rfile.peek(3)
|
||||
except netlib.exceptions.TcpException as e:
|
||||
|
Loading…
Reference in New Issue
Block a user