mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
Merge pull request #2545 from mitmproxy/wsproto
Replace our WebSocket stack with wsproto
This commit is contained in:
commit
dfcf62ff2b
@ -13,7 +13,7 @@ mechanism:
|
|||||||
away. Note that mitmproxy's "Limit" option is often the better alternative here, as it is
|
away. Note that mitmproxy's "Limit" option is often the better alternative here, as it is
|
||||||
not affected by the limitations listed below.
|
not affected by the limitations listed below.
|
||||||
|
|
||||||
If you want to peek into (SSL-protected) non-HTTP connections, check out the :ref:`tcpproxy`
|
If you want to peek into (SSL-protected) non-HTTP connections, check out the :ref:`tcp_proxy`
|
||||||
feature.
|
feature.
|
||||||
If you want to ignore traffic from mitmproxy's processing because of large response bodies,
|
If you want to ignore traffic from mitmproxy's processing because of large response bodies,
|
||||||
take a look at the :ref:`streaming` feature.
|
take a look at the :ref:`streaming` feature.
|
||||||
@ -88,7 +88,7 @@ Here are some other examples for ignore patterns:
|
|||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
- :ref:`tcpproxy`
|
- :ref:`tcp_proxy`
|
||||||
- :ref:`streaming`
|
- :ref:`streaming`
|
||||||
- mitmproxy's "Limit" feature
|
- mitmproxy's "Limit" feature
|
||||||
|
|
||||||
|
@ -20,6 +20,15 @@
|
|||||||
mitmweb
|
mitmweb
|
||||||
config
|
config
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:hidden:
|
||||||
|
:caption: Protocols
|
||||||
|
|
||||||
|
protocols/http1
|
||||||
|
protocols/http2
|
||||||
|
protocols/websocket
|
||||||
|
protocols/tcpproxy
|
||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
:hidden:
|
:hidden:
|
||||||
:caption: Features
|
:caption: Features
|
||||||
@ -36,7 +45,6 @@
|
|||||||
features/streaming
|
features/streaming
|
||||||
features/socksproxy
|
features/socksproxy
|
||||||
features/sticky
|
features/sticky
|
||||||
features/tcpproxy
|
|
||||||
features/upstreamproxy
|
features/upstreamproxy
|
||||||
features/upstreamcerts
|
features/upstreamcerts
|
||||||
|
|
||||||
|
15
docs/protocols/http1.rst
Normal file
15
docs/protocols/http1.rst
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
.. _http1_protocol:
|
||||||
|
|
||||||
|
HTTP/1.0 and HTTP/1.1
|
||||||
|
===========================
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
- `RFC7230: HTTP/1.1: Message Syntax and Routing <http://tools.ietf.org/html/rfc7230>`_
|
||||||
|
- `RFC7231: HTTP/1.1: Semantics and Content <http://tools.ietf.org/html/rfc7231>`_
|
||||||
|
|
||||||
|
HTTP/1.0 and HTTP/1.1 support in mitmproxy is based on our custom HTTP stack,
|
||||||
|
which takes care of all semantics and on-the-wire parsing/serialization tasks.
|
||||||
|
|
||||||
|
mitmproxy currently does not support HTTP trailers - but if you want to send
|
||||||
|
us a PR, we promise to take look!
|
16
docs/protocols/http2.rst
Normal file
16
docs/protocols/http2.rst
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
.. _http2_protocol:
|
||||||
|
|
||||||
|
HTTP/2
|
||||||
|
======
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
- `RFC7540: Hypertext Transfer Protocol Version 2 (HTTP/2) <http://tools.ietf.org/html/rfc7540>`_
|
||||||
|
|
||||||
|
HTTP/2 support in mitmproxy is based on the amazing work by the python-hyper
|
||||||
|
community with the `hyper-h2 <https://github.com/python-hyper/hyper-h2>`_
|
||||||
|
project. It fully encapsulates the internal state of HTTP/2 connections and
|
||||||
|
provides an easy-to-use event-based API.
|
||||||
|
|
||||||
|
mitmproxy currently does not support HTTP/2 trailers - but if you want to send
|
||||||
|
us a PR, we promise to take look!
|
@ -1,7 +1,7 @@
|
|||||||
.. _tcpproxy:
|
.. _tcp_proxy:
|
||||||
|
|
||||||
TCP Proxy
|
TCP Proxy / Fallback
|
||||||
=========
|
====================
|
||||||
|
|
||||||
In case mitmproxy does not handle a specific protocol, you can exempt
|
In case mitmproxy does not handle a specific protocol, you can exempt
|
||||||
hostnames from processing, so that mitmproxy acts as a generic TCP forwarder.
|
hostnames from processing, so that mitmproxy acts as a generic TCP forwarder.
|
22
docs/protocols/websocket.rst
Normal file
22
docs/protocols/websocket.rst
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
.. _websocket_protocol:
|
||||||
|
|
||||||
|
WebSocket
|
||||||
|
=========
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
- `RFC6455: The WebSocket Protocol <http://tools.ietf.org/html/rfc6455>`_
|
||||||
|
- `RFC7692: Compression Extensions for WebSocket <http://tools.ietf.org/html/rfc7692>`_
|
||||||
|
|
||||||
|
WebSocket support in mitmproxy is based on the amazing work by the python-hyper
|
||||||
|
community with the `wsproto <https://github.com/python-hyper/wsproto>`_
|
||||||
|
project. It fully encapsulates WebSocket frames/messages/connections and
|
||||||
|
provides an easy-to-use event-based API.
|
||||||
|
|
||||||
|
mitmproxy fully supports the compression extension for WebSocket messages,
|
||||||
|
provided by wsproto.
|
||||||
|
|
||||||
|
If an endpoint sends a PING to mitmproxy, a PONG will be sent back immediately
|
||||||
|
(with the same payload if present). To keep the other connection alive, a new
|
||||||
|
PING (without a payload) is sent to the other endpoint. Unsolicited PONG's are
|
||||||
|
not forwarded. All PING's and PONG's are logged (with payload if present).
|
@ -211,7 +211,7 @@ TCP Events
|
|||||||
----------
|
----------
|
||||||
|
|
||||||
These events are called only if the connection is in :ref:`TCP mode
|
These events are called only if the connection is in :ref:`TCP mode
|
||||||
<tcpproxy>`. So, for instance, TCP events are not called for ordinary HTTP/S
|
<tcp_proxy>`. So, for instance, TCP events are not called for ordinary HTTP/S
|
||||||
connections.
|
connections.
|
||||||
|
|
||||||
.. list-table::
|
.. list-table::
|
||||||
|
@ -234,6 +234,8 @@ class Dumper:
|
|||||||
message = f.messages[-1]
|
message = f.messages[-1]
|
||||||
self.echo(f.message_info(message))
|
self.echo(f.message_info(message))
|
||||||
if ctx.options.flow_detail >= 3:
|
if ctx.options.flow_detail >= 3:
|
||||||
|
message = message.from_state(message.get_state())
|
||||||
|
message.content = message.content.encode() if isinstance(message.content, str) else message.content
|
||||||
self._echo_message(message)
|
self._echo_message(message)
|
||||||
|
|
||||||
def websocket_end(self, f):
|
def websocket_end(self, f):
|
||||||
|
20
mitmproxy/contrib/wsproto/compat.py
Normal file
20
mitmproxy/contrib/wsproto/compat.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
PY2 = sys.version_info.major == 2
|
||||||
|
PY3 = sys.version_info.major == 3
|
||||||
|
|
||||||
|
|
||||||
|
if PY3:
|
||||||
|
unicode = str
|
||||||
|
|
||||||
|
def Utf8Validator():
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
unicode = unicode
|
||||||
|
try:
|
||||||
|
from wsaccel.utf8validator import Utf8Validator
|
||||||
|
except ImportError:
|
||||||
|
from .utf8validator import Utf8Validator
|
477
mitmproxy/contrib/wsproto/connection.py
Normal file
477
mitmproxy/contrib/wsproto/connection.py
Normal file
@ -0,0 +1,477 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
wsproto/connection
|
||||||
|
~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
An implementation of a WebSocket connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import h11
|
||||||
|
|
||||||
|
from .events import (
|
||||||
|
ConnectionRequested, ConnectionEstablished, ConnectionClosed,
|
||||||
|
ConnectionFailed, TextReceived, BytesReceived, PingReceived, PongReceived
|
||||||
|
)
|
||||||
|
from .frame_protocol import FrameProtocol, ParseFailed, CloseReason, Opcode
|
||||||
|
|
||||||
|
|
||||||
|
# RFC6455, Section 1.3 - Opening Handshake
|
||||||
|
ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionState(Enum):
|
||||||
|
"""
|
||||||
|
RFC 6455, Section 4 - Opening Handshake
|
||||||
|
"""
|
||||||
|
CONNECTING = 0
|
||||||
|
OPEN = 1
|
||||||
|
CLOSING = 2
|
||||||
|
CLOSED = 3
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionType(Enum):
|
||||||
|
CLIENT = 1
|
||||||
|
SERVER = 2
|
||||||
|
|
||||||
|
|
||||||
|
CLIENT = ConnectionType.CLIENT
|
||||||
|
SERVER = ConnectionType.SERVER
|
||||||
|
|
||||||
|
|
||||||
|
# Some convenience utilities for working with HTTP headers
|
||||||
|
def _normed_header_dict(h11_headers):
|
||||||
|
# This mangles Set-Cookie headers. But it happens that we don't care about
|
||||||
|
# any of those, so it's OK. For every other HTTP header, if there are
|
||||||
|
# multiple instances then you're allowed to join them together with
|
||||||
|
# commas.
|
||||||
|
name_to_values = {}
|
||||||
|
for name, value in h11_headers:
|
||||||
|
name_to_values.setdefault(name, []).append(value)
|
||||||
|
name_to_normed_value = {}
|
||||||
|
for name, values in name_to_values.items():
|
||||||
|
name_to_normed_value[name] = b", ".join(values)
|
||||||
|
return name_to_normed_value
|
||||||
|
|
||||||
|
|
||||||
|
# We use this for parsing the proposed protocol list, and for parsing the
|
||||||
|
# proposed and accepted extension lists. For the proposed protocol list it's
|
||||||
|
# fine, because the ABNF is just 1#token. But for the extension lists, it's
|
||||||
|
# wrong, because those can contain quoted strings, which can in turn contain
|
||||||
|
# commas. XX FIXME
|
||||||
|
def _split_comma_header(value):
|
||||||
|
return [piece.decode('ascii').strip() for piece in value.split(b',')]
|
||||||
|
|
||||||
|
|
||||||
|
class WSConnection(object):
|
||||||
|
"""
|
||||||
|
A low-level WebSocket connection object.
|
||||||
|
|
||||||
|
This wraps two other protocol objects, an HTTP/1.1 protocol object used
|
||||||
|
to do the initial HTTP upgrade handshake and a WebSocket frame protocol
|
||||||
|
object used to exchange messages and other control frames.
|
||||||
|
|
||||||
|
:param conn_type: Whether this object is on the client- or server-side of
|
||||||
|
a connection. To initialise as a client pass ``CLIENT`` otherwise
|
||||||
|
pass ``SERVER``.
|
||||||
|
:type conn_type: ``ConnectionType``
|
||||||
|
|
||||||
|
:param host: The hostname to pass to the server when acting as a client.
|
||||||
|
:type host: ``str``
|
||||||
|
|
||||||
|
:param resource: The resource (aka path) to pass to the server when acting
|
||||||
|
as a client.
|
||||||
|
:type resource: ``str``
|
||||||
|
|
||||||
|
:param extensions: A list of extensions to use on this connection.
|
||||||
|
Extensions should be instances of a subclass of
|
||||||
|
:class:`Extension <wsproto.extensions.Extension>`.
|
||||||
|
|
||||||
|
:param subprotocols: A list of subprotocols to request when acting as a
|
||||||
|
client, ordered by preference. This has no impact on the connection
|
||||||
|
itself.
|
||||||
|
:type subprotocol: ``list`` of ``str``
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, conn_type, host=None, resource=None, extensions=None,
|
||||||
|
subprotocols=None):
|
||||||
|
self.client = conn_type is ConnectionType.CLIENT
|
||||||
|
|
||||||
|
self.host = host
|
||||||
|
self.resource = resource
|
||||||
|
|
||||||
|
self.subprotocols = subprotocols or []
|
||||||
|
self.extensions = extensions or []
|
||||||
|
|
||||||
|
self.version = b'13'
|
||||||
|
|
||||||
|
self._state = ConnectionState.CONNECTING
|
||||||
|
self._close_reason = None
|
||||||
|
|
||||||
|
self._nonce = None
|
||||||
|
self._outgoing = b''
|
||||||
|
self._events = deque()
|
||||||
|
self._proto = None
|
||||||
|
|
||||||
|
if self.client:
|
||||||
|
self._upgrade_connection = h11.Connection(h11.CLIENT)
|
||||||
|
else:
|
||||||
|
self._upgrade_connection = h11.Connection(h11.SERVER)
|
||||||
|
|
||||||
|
if self.client:
|
||||||
|
if self.host is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Host must not be None for a client-side connection.")
|
||||||
|
if self.resource is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Resource must not be None for a client-side connection.")
|
||||||
|
self.initiate_connection()
|
||||||
|
|
||||||
|
def initiate_connection(self):
|
||||||
|
self._generate_nonce()
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
b"Host": self.host.encode('ascii'),
|
||||||
|
b"Upgrade": b'WebSocket',
|
||||||
|
b"Connection": b'Upgrade',
|
||||||
|
b"Sec-WebSocket-Key": self._nonce,
|
||||||
|
b"Sec-WebSocket-Version": self.version,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.subprotocols:
|
||||||
|
headers[b"Sec-WebSocket-Protocol"] = ", ".join(self.subprotocols)
|
||||||
|
|
||||||
|
if self.extensions:
|
||||||
|
offers = {e.name: e.offer(self) for e in self.extensions}
|
||||||
|
extensions = []
|
||||||
|
for name, params in offers.items():
|
||||||
|
if params is True:
|
||||||
|
extensions.append(name.encode('ascii'))
|
||||||
|
elif params:
|
||||||
|
# py34 annoyance: doesn't support bytestring formatting
|
||||||
|
extensions.append(('%s; %s' % (name, params))
|
||||||
|
.encode("ascii"))
|
||||||
|
if extensions:
|
||||||
|
headers[b'Sec-WebSocket-Extensions'] = b', '.join(extensions)
|
||||||
|
|
||||||
|
upgrade = h11.Request(method=b'GET', target=self.resource,
|
||||||
|
headers=headers.items())
|
||||||
|
self._outgoing += self._upgrade_connection.send(upgrade)
|
||||||
|
|
||||||
|
def send_data(self, payload, final=True):
|
||||||
|
"""
|
||||||
|
Send a message or part of a message to the remote peer.
|
||||||
|
|
||||||
|
If ``final`` is ``False`` it indicates that this is part of a longer
|
||||||
|
message. If ``final`` is ``True`` it indicates that this is either a
|
||||||
|
self-contained message or the last part of a longer message.
|
||||||
|
|
||||||
|
If ``payload`` is of type ``bytes`` then the message is flagged as
|
||||||
|
being binary If it is of type ``str`` encoded as UTF-8 and sent as
|
||||||
|
text.
|
||||||
|
|
||||||
|
:param payload: The message body to send.
|
||||||
|
:type payload: ``bytes`` or ``str``
|
||||||
|
|
||||||
|
:param final: Whether there are more parts to this message to be sent.
|
||||||
|
:type final: ``bool``
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._outgoing += self._proto.send_data(payload, final)
|
||||||
|
|
||||||
|
def close(self, code=CloseReason.NORMAL_CLOSURE, reason=None):
|
||||||
|
self._outgoing += self._proto.close(code, reason)
|
||||||
|
self._state = ConnectionState.CLOSING
|
||||||
|
|
||||||
|
@property
|
||||||
|
def closed(self):
|
||||||
|
return self._state is ConnectionState.CLOSED
|
||||||
|
|
||||||
|
def bytes_to_send(self, amount=None):
|
||||||
|
"""
|
||||||
|
Return any data that is to be sent to the remote peer.
|
||||||
|
|
||||||
|
:param amount: (optional) The maximum number of bytes to be provided.
|
||||||
|
If ``None`` or not provided it will return all available bytes.
|
||||||
|
:type amount: ``int``
|
||||||
|
"""
|
||||||
|
|
||||||
|
if amount is None:
|
||||||
|
data = self._outgoing
|
||||||
|
self._outgoing = b''
|
||||||
|
else:
|
||||||
|
data = self._outgoing[:amount]
|
||||||
|
self._outgoing = self._outgoing[amount:]
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def receive_bytes(self, data):
|
||||||
|
"""
|
||||||
|
Pass some received bytes to the connection for processing.
|
||||||
|
|
||||||
|
:param data: The data received from the remote peer.
|
||||||
|
:type data: ``bytes``
|
||||||
|
"""
|
||||||
|
|
||||||
|
if data is None and self._state is ConnectionState.OPEN:
|
||||||
|
# "If _The WebSocket Connection is Closed_ and no Close control
|
||||||
|
# frame was received by the endpoint (such as could occur if the
|
||||||
|
# underlying transport connection is lost), _The WebSocket
|
||||||
|
# Connection Close Code_ is considered to be 1006."
|
||||||
|
self._events.append(ConnectionClosed(CloseReason.ABNORMAL_CLOSURE))
|
||||||
|
self._state = ConnectionState.CLOSED
|
||||||
|
return
|
||||||
|
elif data is None:
|
||||||
|
self._state = ConnectionState.CLOSED
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._state is ConnectionState.CONNECTING:
|
||||||
|
event, data = self._process_upgrade(data)
|
||||||
|
if event is not None:
|
||||||
|
self._events.append(event)
|
||||||
|
|
||||||
|
if self._state is ConnectionState.OPEN:
|
||||||
|
self._proto.receive_bytes(data)
|
||||||
|
|
||||||
|
def _process_upgrade(self, data):
|
||||||
|
self._upgrade_connection.receive_data(data)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
event = self._upgrade_connection.next_event()
|
||||||
|
except h11.RemoteProtocolError:
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Bad HTTP message"), b''
|
||||||
|
if event is h11.NEED_DATA:
|
||||||
|
break
|
||||||
|
elif self.client and isinstance(event, (h11.InformationalResponse,
|
||||||
|
h11.Response)):
|
||||||
|
data = self._upgrade_connection.trailing_data[0]
|
||||||
|
return self._establish_client_connection(event), data
|
||||||
|
elif not self.client and isinstance(event, h11.Request):
|
||||||
|
return self._process_connection_request(event), None
|
||||||
|
else:
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Bad HTTP message"), b''
|
||||||
|
|
||||||
|
self._incoming = b''
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def events(self):
|
||||||
|
"""
|
||||||
|
Return a generator that provides any events that have been generated
|
||||||
|
by protocol activity.
|
||||||
|
|
||||||
|
:returns: generator
|
||||||
|
"""
|
||||||
|
|
||||||
|
while self._events:
|
||||||
|
yield self._events.popleft()
|
||||||
|
|
||||||
|
if self._proto is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
for frame in self._proto.received_frames():
|
||||||
|
if frame.opcode is Opcode.PING:
|
||||||
|
assert frame.frame_finished and frame.message_finished
|
||||||
|
self._outgoing += self._proto.pong(frame.payload)
|
||||||
|
yield PingReceived(frame.payload)
|
||||||
|
|
||||||
|
elif frame.opcode is Opcode.PONG:
|
||||||
|
assert frame.frame_finished and frame.message_finished
|
||||||
|
yield PongReceived(frame.payload)
|
||||||
|
|
||||||
|
elif frame.opcode is Opcode.CLOSE:
|
||||||
|
code, reason = frame.payload
|
||||||
|
self.close(code, reason)
|
||||||
|
yield ConnectionClosed(code, reason)
|
||||||
|
|
||||||
|
elif frame.opcode is Opcode.TEXT:
|
||||||
|
yield TextReceived(frame.payload,
|
||||||
|
frame.frame_finished,
|
||||||
|
frame.message_finished)
|
||||||
|
|
||||||
|
elif frame.opcode is Opcode.BINARY:
|
||||||
|
yield BytesReceived(frame.payload,
|
||||||
|
frame.frame_finished,
|
||||||
|
frame.message_finished)
|
||||||
|
except ParseFailed as exc:
|
||||||
|
# XX FIXME: apparently autobahn intentionally deviates from the
|
||||||
|
# spec in that on protocol errors it just closes the connection
|
||||||
|
# rather than trying to send a CLOSE frame. Investigate whether we
|
||||||
|
# should do the same.
|
||||||
|
self.close(code=exc.code, reason=str(exc))
|
||||||
|
yield ConnectionClosed(exc.code, reason=str(exc))
|
||||||
|
|
||||||
|
def _generate_nonce(self):
|
||||||
|
# os.urandom may be overkill for this use case, but I don't think this
|
||||||
|
# is a bottleneck, and better safe than sorry...
|
||||||
|
self._nonce = base64.b64encode(os.urandom(16))
|
||||||
|
|
||||||
|
def _generate_accept_token(self, token):
|
||||||
|
accept_token = token + ACCEPT_GUID
|
||||||
|
accept_token = hashlib.sha1(accept_token).digest()
|
||||||
|
return base64.b64encode(accept_token)
|
||||||
|
|
||||||
|
def _establish_client_connection(self, event):
|
||||||
|
if event.status_code != 101:
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Bad status code from server")
|
||||||
|
headers = _normed_header_dict(event.headers)
|
||||||
|
if headers[b'connection'].lower() != b'upgrade':
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Missing Connection: Upgrade header")
|
||||||
|
if headers[b'upgrade'].lower() != b'websocket':
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Missing Upgrade: WebSocket header")
|
||||||
|
|
||||||
|
accept_token = self._generate_accept_token(self._nonce)
|
||||||
|
if headers[b'sec-websocket-accept'] != accept_token:
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Bad accept token")
|
||||||
|
|
||||||
|
subprotocol = headers.get(b'sec-websocket-protocol', None)
|
||||||
|
if subprotocol is not None:
|
||||||
|
subprotocol = subprotocol.decode('ascii')
|
||||||
|
if subprotocol not in self.subprotocols:
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"unrecognized subprotocol {!r}"
|
||||||
|
.format(subprotocol))
|
||||||
|
|
||||||
|
extensions = headers.get(b'sec-websocket-extensions', None)
|
||||||
|
if extensions:
|
||||||
|
accepts = _split_comma_header(extensions)
|
||||||
|
|
||||||
|
for accept in accepts:
|
||||||
|
name = accept.split(';', 1)[0].strip()
|
||||||
|
for extension in self.extensions:
|
||||||
|
if extension.name == name:
|
||||||
|
extension.finalize(self, accept)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"unrecognized extension {!r}"
|
||||||
|
.format(name))
|
||||||
|
|
||||||
|
self._proto = FrameProtocol(self.client, self.extensions)
|
||||||
|
self._state = ConnectionState.OPEN
|
||||||
|
return ConnectionEstablished(subprotocol, extensions)
|
||||||
|
|
||||||
|
def _process_connection_request(self, event):
|
||||||
|
if event.method != b'GET':
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Request method must be GET")
|
||||||
|
headers = _normed_header_dict(event.headers)
|
||||||
|
if headers[b'connection'].lower() != b'upgrade':
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Missing Connection: Upgrade header")
|
||||||
|
if headers[b'upgrade'].lower() != b'websocket':
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Missing Upgrade: WebSocket header")
|
||||||
|
|
||||||
|
if b'sec-websocket-version' not in headers:
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Missing Sec-WebSocket-Version header")
|
||||||
|
# XX FIXME: need to check Sec-Websocket-Version, and respond with a
|
||||||
|
# 400 if it's not what we expect
|
||||||
|
|
||||||
|
if b'sec-websocket-protocol' in headers:
|
||||||
|
proposed_subprotocols = _split_comma_header(
|
||||||
|
headers[b'sec-websocket-protocol'])
|
||||||
|
else:
|
||||||
|
proposed_subprotocols = []
|
||||||
|
|
||||||
|
if b'sec-websocket-key' not in headers:
|
||||||
|
return ConnectionFailed(CloseReason.PROTOCOL_ERROR,
|
||||||
|
"Missing Sec-WebSocket-Key header")
|
||||||
|
|
||||||
|
return ConnectionRequested(proposed_subprotocols, event)
|
||||||
|
|
||||||
|
def _extension_accept(self, extensions_header):
|
||||||
|
accepts = {}
|
||||||
|
offers = _split_comma_header(extensions_header)
|
||||||
|
|
||||||
|
for offer in offers:
|
||||||
|
name = offer.split(';', 1)[0].strip()
|
||||||
|
for extension in self.extensions:
|
||||||
|
if extension.name == name:
|
||||||
|
accept = extension.accept(self, offer)
|
||||||
|
if accept is True:
|
||||||
|
accepts[extension.name] = True
|
||||||
|
elif accept is not False and accept is not None:
|
||||||
|
accepts[extension.name] = accept.encode('ascii')
|
||||||
|
|
||||||
|
if accepts:
|
||||||
|
extensions = []
|
||||||
|
for name, params in accepts.items():
|
||||||
|
if params is True:
|
||||||
|
extensions.append(name.encode('ascii'))
|
||||||
|
else:
|
||||||
|
# py34 annoyance: doesn't support bytestring formatting
|
||||||
|
params = params.decode("ascii")
|
||||||
|
extensions.append(('%s; %s' % (name, params))
|
||||||
|
.encode("ascii"))
|
||||||
|
return b', '.join(extensions)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def accept(self, event, subprotocol=None):
|
||||||
|
request = event.h11request
|
||||||
|
request_headers = _normed_header_dict(request.headers)
|
||||||
|
|
||||||
|
nonce = request_headers[b'sec-websocket-key']
|
||||||
|
accept_token = self._generate_accept_token(nonce)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
b"Upgrade": b'WebSocket',
|
||||||
|
b"Connection": b'Upgrade',
|
||||||
|
b"Sec-WebSocket-Accept": accept_token,
|
||||||
|
}
|
||||||
|
|
||||||
|
if subprotocol is not None:
|
||||||
|
if subprotocol not in event.proposed_subprotocols:
|
||||||
|
raise ValueError(
|
||||||
|
"unexpected subprotocol {!r}".format(subprotocol))
|
||||||
|
headers[b'Sec-WebSocket-Protocol'] = subprotocol
|
||||||
|
|
||||||
|
extensions = request_headers.get(b'sec-websocket-extensions', None)
|
||||||
|
if extensions:
|
||||||
|
accepts = self._extension_accept(extensions)
|
||||||
|
if accepts:
|
||||||
|
headers[b"Sec-WebSocket-Extensions"] = accepts
|
||||||
|
|
||||||
|
response = h11.InformationalResponse(status_code=101,
|
||||||
|
headers=headers.items())
|
||||||
|
self._outgoing += self._upgrade_connection.send(response)
|
||||||
|
self._proto = FrameProtocol(self.client, self.extensions)
|
||||||
|
self._state = ConnectionState.OPEN
|
||||||
|
|
||||||
|
def ping(self, payload=None):
|
||||||
|
"""
|
||||||
|
Send a PING message to the peer.
|
||||||
|
|
||||||
|
:param payload: an optional payload to send with the message
|
||||||
|
"""
|
||||||
|
|
||||||
|
payload = bytes(payload or b'')
|
||||||
|
self._outgoing += self._proto.ping(payload)
|
||||||
|
|
||||||
|
def pong(self, payload=None):
|
||||||
|
"""
|
||||||
|
Send a PONG message to the peer.
|
||||||
|
|
||||||
|
This method can be used to send an unsolicted PONG to the peer.
|
||||||
|
It is not needed otherwise since every received PING causes a
|
||||||
|
corresponding PONG to be sent automatically.
|
||||||
|
|
||||||
|
:param payload: an optional payload to send with the message
|
||||||
|
"""
|
||||||
|
|
||||||
|
payload = bytes(payload or b'')
|
||||||
|
self._outgoing += self._proto.pong(payload)
|
81
mitmproxy/contrib/wsproto/events.py
Normal file
81
mitmproxy/contrib/wsproto/events.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
wsproto/events
|
||||||
|
~~~~~~~~~~
|
||||||
|
|
||||||
|
Events that result from processing data on a WebSocket connection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionRequested(object):
|
||||||
|
def __init__(self, proposed_subprotocols, h11request):
|
||||||
|
self.proposed_subprotocols = proposed_subprotocols
|
||||||
|
self.h11request = h11request
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
path = self.h11request.target
|
||||||
|
|
||||||
|
headers = dict(self.h11request.headers)
|
||||||
|
host = headers[b'host']
|
||||||
|
version = headers[b'sec-websocket-version']
|
||||||
|
subprotocol = headers.get(b'sec-websocket-protocol', None)
|
||||||
|
extensions = []
|
||||||
|
|
||||||
|
fmt = '<%s host=%s path=%s version=%s subprotocol=%r extensions=%r>'
|
||||||
|
return fmt % (self.__class__.__name__, host, path, version,
|
||||||
|
subprotocol, extensions)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionEstablished(object):
|
||||||
|
def __init__(self, subprotocol=None, extensions=None):
|
||||||
|
self.subprotocol = subprotocol
|
||||||
|
self.extensions = extensions
|
||||||
|
if self.extensions is None:
|
||||||
|
self.extensions = []
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<ConnectionEstablished subprotocol=%r extensions=%r>' % \
|
||||||
|
(self.subprotocol, self.extensions)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionClosed(object):
|
||||||
|
def __init__(self, code, reason=None):
|
||||||
|
self.code = code
|
||||||
|
self.reason = reason
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '<%s code=%r reason="%s">' % (self.__class__.__name__,
|
||||||
|
self.code, self.reason)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionFailed(ConnectionClosed):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DataReceived(object):
|
||||||
|
def __init__(self, data, frame_finished, message_finished):
|
||||||
|
self.data = data
|
||||||
|
# This has no semantic content, but is provided just in case some
|
||||||
|
# weird edge case user wants to be able to reconstruct the
|
||||||
|
# fragmentation pattern of the original stream. You don't want it:
|
||||||
|
self.frame_finished = frame_finished
|
||||||
|
# This is the field that you almost certainly want:
|
||||||
|
self.message_finished = message_finished
|
||||||
|
|
||||||
|
|
||||||
|
class TextReceived(DataReceived):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BytesReceived(DataReceived):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PingReceived(object):
|
||||||
|
def __init__(self, payload):
|
||||||
|
self.payload = payload
|
||||||
|
|
||||||
|
|
||||||
|
class PongReceived(object):
|
||||||
|
def __init__(self, payload):
|
||||||
|
self.payload = payload
|
257
mitmproxy/contrib/wsproto/extensions.py
Normal file
257
mitmproxy/contrib/wsproto/extensions.py
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
wsproto/extensions
|
||||||
|
~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
WebSocket extensions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import zlib
|
||||||
|
|
||||||
|
from .frame_protocol import CloseReason, Opcode, RsvBits
|
||||||
|
|
||||||
|
|
||||||
|
class Extension(object):
|
||||||
|
name = None
|
||||||
|
|
||||||
|
def enabled(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def offer(self, connection):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def accept(self, connection, offer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def finalize(self, connection, offer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def frame_inbound_header(self, proto, opcode, rsv, payload_length):
|
||||||
|
return RsvBits(False, False, False)
|
||||||
|
|
||||||
|
def frame_inbound_payload_data(self, proto, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
def frame_inbound_complete(self, proto, fin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def frame_outbound(self, proto, opcode, rsv, data, fin):
|
||||||
|
return (rsv, data)
|
||||||
|
|
||||||
|
|
||||||
|
class PerMessageDeflate(Extension):
|
||||||
|
name = 'permessage-deflate'
|
||||||
|
|
||||||
|
DEFAULT_CLIENT_MAX_WINDOW_BITS = 15
|
||||||
|
DEFAULT_SERVER_MAX_WINDOW_BITS = 15
|
||||||
|
|
||||||
|
def __init__(self, client_no_context_takeover=False,
|
||||||
|
client_max_window_bits=None, server_no_context_takeover=False,
|
||||||
|
server_max_window_bits=None):
|
||||||
|
self.client_no_context_takeover = client_no_context_takeover
|
||||||
|
if client_max_window_bits is None:
|
||||||
|
client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
|
||||||
|
self.client_max_window_bits = client_max_window_bits
|
||||||
|
self.server_no_context_takeover = server_no_context_takeover
|
||||||
|
if server_max_window_bits is None:
|
||||||
|
server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS
|
||||||
|
self.server_max_window_bits = server_max_window_bits
|
||||||
|
|
||||||
|
self._compressor = None
|
||||||
|
self._decompressor = None
|
||||||
|
# This refers to the current frame
|
||||||
|
self._inbound_is_compressible = None
|
||||||
|
# This refers to the ongoing message (which might span multiple
|
||||||
|
# frames). Only the first frame in a fragmented message is flagged for
|
||||||
|
# compression, so this carries that bit forward.
|
||||||
|
self._inbound_compressed = None
|
||||||
|
|
||||||
|
self._enabled = False
|
||||||
|
|
||||||
|
def _compressible_opcode(self, opcode):
|
||||||
|
return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION)
|
||||||
|
|
||||||
|
def enabled(self):
|
||||||
|
return self._enabled
|
||||||
|
|
||||||
|
def offer(self, connection):
|
||||||
|
parameters = [
|
||||||
|
'client_max_window_bits=%d' % self.client_max_window_bits,
|
||||||
|
'server_max_window_bits=%d' % self.server_max_window_bits,
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.client_no_context_takeover:
|
||||||
|
parameters.append('client_no_context_takeover')
|
||||||
|
if self.server_no_context_takeover:
|
||||||
|
parameters.append('server_no_context_takeover')
|
||||||
|
|
||||||
|
return '; '.join(parameters)
|
||||||
|
|
||||||
|
def finalize(self, connection, offer):
|
||||||
|
bits = [b.strip() for b in offer.split(';')]
|
||||||
|
for bit in bits[1:]:
|
||||||
|
if bit.startswith('client_no_context_takeover'):
|
||||||
|
self.client_no_context_takeover = True
|
||||||
|
elif bit.startswith('server_no_context_takeover'):
|
||||||
|
self.server_no_context_takeover = True
|
||||||
|
elif bit.startswith('client_max_window_bits'):
|
||||||
|
self.client_max_window_bits = int(bit.split('=', 1)[1].strip())
|
||||||
|
elif bit.startswith('server_max_window_bits'):
|
||||||
|
self.server_max_window_bits = int(bit.split('=', 1)[1].strip())
|
||||||
|
|
||||||
|
self._enabled = True
|
||||||
|
|
||||||
|
def _parse_params(self, params):
|
||||||
|
client_max_window_bits = None
|
||||||
|
server_max_window_bits = None
|
||||||
|
|
||||||
|
bits = [b.strip() for b in params.split(';')]
|
||||||
|
for bit in bits[1:]:
|
||||||
|
if bit.startswith('client_no_context_takeover'):
|
||||||
|
self.client_no_context_takeover = True
|
||||||
|
elif bit.startswith('server_no_context_takeover'):
|
||||||
|
self.server_no_context_takeover = True
|
||||||
|
elif bit.startswith('client_max_window_bits'):
|
||||||
|
if '=' in bit:
|
||||||
|
client_max_window_bits = int(bit.split('=', 1)[1].strip())
|
||||||
|
else:
|
||||||
|
client_max_window_bits = self.client_max_window_bits
|
||||||
|
elif bit.startswith('server_max_window_bits'):
|
||||||
|
if '=' in bit:
|
||||||
|
server_max_window_bits = int(bit.split('=', 1)[1].strip())
|
||||||
|
else:
|
||||||
|
server_max_window_bits = self.server_max_window_bits
|
||||||
|
|
||||||
|
return client_max_window_bits, server_max_window_bits
|
||||||
|
|
||||||
|
def accept(self, connection, offer):
|
||||||
|
client_max_window_bits, server_max_window_bits = \
|
||||||
|
self._parse_params(offer)
|
||||||
|
|
||||||
|
self._enabled = True
|
||||||
|
|
||||||
|
parameters = []
|
||||||
|
|
||||||
|
if self.client_no_context_takeover:
|
||||||
|
parameters.append('client_no_context_takeover')
|
||||||
|
if client_max_window_bits is not None:
|
||||||
|
parameters.append('client_max_window_bits=%d' %
|
||||||
|
client_max_window_bits)
|
||||||
|
self.client_max_window_bits = client_max_window_bits
|
||||||
|
if self.server_no_context_takeover:
|
||||||
|
parameters.append('server_no_context_takeover')
|
||||||
|
if server_max_window_bits is not None:
|
||||||
|
parameters.append('server_max_window_bits=%d' %
|
||||||
|
server_max_window_bits)
|
||||||
|
self.server_max_window_bits = server_max_window_bits
|
||||||
|
|
||||||
|
return '; '.join(parameters)
|
||||||
|
|
||||||
|
def frame_inbound_header(self, proto, opcode, rsv, payload_length):
|
||||||
|
if rsv.rsv1 and opcode.iscontrol():
|
||||||
|
return CloseReason.PROTOCOL_ERROR
|
||||||
|
elif rsv.rsv1 and opcode is Opcode.CONTINUATION:
|
||||||
|
return CloseReason.PROTOCOL_ERROR
|
||||||
|
|
||||||
|
self._inbound_is_compressible = self._compressible_opcode(opcode)
|
||||||
|
|
||||||
|
if self._inbound_compressed is None:
|
||||||
|
self._inbound_compressed = rsv.rsv1
|
||||||
|
if self._inbound_compressed:
|
||||||
|
assert self._inbound_is_compressible
|
||||||
|
if proto.client:
|
||||||
|
bits = self.server_max_window_bits
|
||||||
|
else:
|
||||||
|
bits = self.client_max_window_bits
|
||||||
|
if self._decompressor is None:
|
||||||
|
self._decompressor = zlib.decompressobj(-int(bits))
|
||||||
|
|
||||||
|
return RsvBits(True, False, False)
|
||||||
|
|
||||||
|
def frame_inbound_payload_data(self, proto, data):
|
||||||
|
if not self._inbound_compressed or not self._inbound_is_compressible:
|
||||||
|
return data
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self._decompressor.decompress(bytes(data))
|
||||||
|
except zlib.error:
|
||||||
|
return CloseReason.INVALID_FRAME_PAYLOAD_DATA
|
||||||
|
|
||||||
|
def frame_inbound_complete(self, proto, fin):
|
||||||
|
if not fin:
|
||||||
|
return
|
||||||
|
elif not self._inbound_is_compressible:
|
||||||
|
return
|
||||||
|
elif not self._inbound_compressed:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = self._decompressor.decompress(b'\x00\x00\xff\xff')
|
||||||
|
data += self._decompressor.flush()
|
||||||
|
except zlib.error:
|
||||||
|
return CloseReason.INVALID_FRAME_PAYLOAD_DATA
|
||||||
|
|
||||||
|
if proto.client:
|
||||||
|
no_context_takeover = self.server_no_context_takeover
|
||||||
|
else:
|
||||||
|
no_context_takeover = self.client_no_context_takeover
|
||||||
|
|
||||||
|
if no_context_takeover:
|
||||||
|
self._decompressor = None
|
||||||
|
|
||||||
|
self._inbound_compressed = None
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def frame_outbound(self, proto, opcode, rsv, data, fin):
|
||||||
|
if not self._compressible_opcode(opcode):
|
||||||
|
return (rsv, data)
|
||||||
|
|
||||||
|
if opcode is not Opcode.CONTINUATION:
|
||||||
|
rsv = RsvBits(True, *rsv[1:])
|
||||||
|
|
||||||
|
if self._compressor is None:
|
||||||
|
assert opcode is not Opcode.CONTINUATION
|
||||||
|
if proto.client:
|
||||||
|
bits = self.client_max_window_bits
|
||||||
|
else:
|
||||||
|
bits = self.server_max_window_bits
|
||||||
|
self._compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION,
|
||||||
|
zlib.DEFLATED, -int(bits))
|
||||||
|
|
||||||
|
data = self._compressor.compress(bytes(data))
|
||||||
|
|
||||||
|
if fin:
|
||||||
|
data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
|
||||||
|
data = data[:-4]
|
||||||
|
|
||||||
|
if proto.client:
|
||||||
|
no_context_takeover = self.client_no_context_takeover
|
||||||
|
else:
|
||||||
|
no_context_takeover = self.server_no_context_takeover
|
||||||
|
|
||||||
|
if no_context_takeover:
|
||||||
|
self._compressor = None
|
||||||
|
|
||||||
|
return (rsv, data)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
descr = ['client_max_window_bits=%d' % self.client_max_window_bits]
|
||||||
|
if self.client_no_context_takeover:
|
||||||
|
descr.append('client_no_context_takeover')
|
||||||
|
descr.append('server_max_window_bits=%d' % self.server_max_window_bits)
|
||||||
|
if self.server_no_context_takeover:
|
||||||
|
descr.append('server_no_context_takeover')
|
||||||
|
|
||||||
|
descr = '; '.join(descr)
|
||||||
|
|
||||||
|
return '<%s %s>' % (self.__class__.__name__, descr)
|
||||||
|
|
||||||
|
|
||||||
|
#: SUPPORTED_EXTENSIONS maps all supported extension names to their class.
|
||||||
|
#: This can be used to iterate all supported extensions of wsproto, instantiate
|
||||||
|
#: new extensions based on their name, or check if a given extension is
|
||||||
|
#: supported or not.
|
||||||
|
SUPPORTED_EXTENSIONS = {
|
||||||
|
PerMessageDeflate.name: PerMessageDeflate
|
||||||
|
}
|
579
mitmproxy/contrib/wsproto/frame_protocol.py
Normal file
579
mitmproxy/contrib/wsproto/frame_protocol.py
Normal file
@ -0,0 +1,579 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
wsproto/frame_protocol
|
||||||
|
~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
WebSocket frame protocol implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import itertools
|
||||||
|
import struct
|
||||||
|
from codecs import getincrementaldecoder
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
from enum import Enum, IntEnum
|
||||||
|
|
||||||
|
from .compat import unicode, Utf8Validator
|
||||||
|
|
||||||
|
try:
|
||||||
|
from wsaccel.xormask import XorMaskerSimple
|
||||||
|
except ImportError:
|
||||||
|
class XorMaskerSimple:
|
||||||
|
def __init__(self, masking_key):
|
||||||
|
self._maskbytes = itertools.cycle(bytearray(masking_key))
|
||||||
|
|
||||||
|
def process(self, data):
|
||||||
|
maskbytes = self._maskbytes
|
||||||
|
return bytearray(b ^ next(maskbytes) for b in bytearray(data))
|
||||||
|
|
||||||
|
|
||||||
|
class XorMaskerNull:
|
||||||
|
def process(self, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
# RFC6455, Section 5.2 - Base Framing Protocol
|
||||||
|
|
||||||
|
# Payload length constants
|
||||||
|
PAYLOAD_LENGTH_TWO_BYTE = 126
|
||||||
|
PAYLOAD_LENGTH_EIGHT_BYTE = 127
|
||||||
|
MAX_PAYLOAD_NORMAL = 125
|
||||||
|
MAX_PAYLOAD_TWO_BYTE = 2 ** 16 - 1
|
||||||
|
MAX_PAYLOAD_EIGHT_BYTE = 2 ** 64 - 1
|
||||||
|
MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE
|
||||||
|
|
||||||
|
# MASK and PAYLOAD LEN are packed into a byte
|
||||||
|
MASK_MASK = 0x80
|
||||||
|
PAYLOAD_LEN_MASK = 0x7f
|
||||||
|
|
||||||
|
# FIN, RSV[123] and OPCODE are packed into a single byte
|
||||||
|
FIN_MASK = 0x80
|
||||||
|
RSV1_MASK = 0x40
|
||||||
|
RSV2_MASK = 0x20
|
||||||
|
RSV3_MASK = 0x10
|
||||||
|
OPCODE_MASK = 0x0f
|
||||||
|
|
||||||
|
|
||||||
|
class Opcode(IntEnum):
|
||||||
|
"""
|
||||||
|
RFC 6455, Section 5.2 - Base Framing Protocol
|
||||||
|
"""
|
||||||
|
CONTINUATION = 0x0
|
||||||
|
TEXT = 0x1
|
||||||
|
BINARY = 0x2
|
||||||
|
CLOSE = 0x8
|
||||||
|
PING = 0x9
|
||||||
|
PONG = 0xA
|
||||||
|
|
||||||
|
def iscontrol(self):
|
||||||
|
return bool(self & 0x08)
|
||||||
|
|
||||||
|
|
||||||
|
class CloseReason(IntEnum):
|
||||||
|
"""
|
||||||
|
RFC 6455, Section 7.4.1 - Defined Status Codes
|
||||||
|
"""
|
||||||
|
NORMAL_CLOSURE = 1000
|
||||||
|
GOING_AWAY = 1001
|
||||||
|
PROTOCOL_ERROR = 1002
|
||||||
|
UNSUPPORTED_DATA = 1003
|
||||||
|
NO_STATUS_RCVD = 1005
|
||||||
|
ABNORMAL_CLOSURE = 1006
|
||||||
|
INVALID_FRAME_PAYLOAD_DATA = 1007
|
||||||
|
POLICY_VIOLATION = 1008
|
||||||
|
MESSAGE_TOO_BIG = 1009
|
||||||
|
MANDATORY_EXT = 1010
|
||||||
|
INTERNAL_ERROR = 1011
|
||||||
|
SERVICE_RESTART = 1012
|
||||||
|
TRY_AGAIN_LATER = 1013
|
||||||
|
TLS_HANDSHAKE_FAILED = 1015
|
||||||
|
|
||||||
|
|
||||||
|
# RFC 6455, Section 7.4.1 - Defined Status Codes
|
||||||
|
LOCAL_ONLY_CLOSE_REASONS = (
|
||||||
|
CloseReason.NO_STATUS_RCVD,
|
||||||
|
CloseReason.ABNORMAL_CLOSURE,
|
||||||
|
CloseReason.TLS_HANDSHAKE_FAILED,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# RFC 6455, Section 7.4.2 - Status Code Ranges
|
||||||
|
MIN_CLOSE_REASON = 1000
|
||||||
|
MIN_PROTOCOL_CLOSE_REASON = 1000
|
||||||
|
MAX_PROTOCOL_CLOSE_REASON = 2999
|
||||||
|
MIN_LIBRARY_CLOSE_REASON = 3000
|
||||||
|
MAX_LIBRARY_CLOSE_REASON = 3999
|
||||||
|
MIN_PRIVATE_CLOSE_REASON = 4000
|
||||||
|
MAX_PRIVATE_CLOSE_REASON = 4999
|
||||||
|
MAX_CLOSE_REASON = 4999
|
||||||
|
|
||||||
|
|
||||||
|
NULL_MASK = struct.pack("!I", 0)
|
||||||
|
|
||||||
|
|
||||||
|
class ParseFailed(Exception):
|
||||||
|
def __init__(self, msg, code=CloseReason.PROTOCOL_ERROR):
|
||||||
|
super(ParseFailed, self).__init__(msg)
|
||||||
|
self.code = code
|
||||||
|
|
||||||
|
|
||||||
|
Header = namedtuple("Header", "fin rsv opcode payload_len masking_key".split())
|
||||||
|
|
||||||
|
|
||||||
|
Frame = namedtuple("Frame",
|
||||||
|
"opcode payload frame_finished message_finished".split())
|
||||||
|
|
||||||
|
|
||||||
|
RsvBits = namedtuple("RsvBits", "rsv1 rsv2 rsv3".split())
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_utf8(data, nbytes):
|
||||||
|
if len(data) <= nbytes:
|
||||||
|
return data
|
||||||
|
|
||||||
|
# Truncate
|
||||||
|
data = data[:nbytes]
|
||||||
|
# But we might have cut a codepoint in half, in which case we want to
|
||||||
|
# discard the partial character so the data is at least
|
||||||
|
# well-formed. This is a little inefficient since it processes the
|
||||||
|
# whole message twice when in theory we could just peek at the last
|
||||||
|
# few characters, but since this is only used for close messages (max
|
||||||
|
# length = 125 bytes) it really doesn't matter.
|
||||||
|
data = data.decode("utf-8", errors="ignore").encode("utf-8")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class Buffer(object):
|
||||||
|
def __init__(self, initial_bytes=None):
|
||||||
|
self.buffer = bytearray()
|
||||||
|
self.bytes_used = 0
|
||||||
|
if initial_bytes:
|
||||||
|
self.feed(initial_bytes)
|
||||||
|
|
||||||
|
def feed(self, new_bytes):
|
||||||
|
self.buffer += new_bytes
|
||||||
|
|
||||||
|
def consume_at_most(self, nbytes):
|
||||||
|
if not nbytes:
|
||||||
|
return bytearray()
|
||||||
|
|
||||||
|
data = self.buffer[self.bytes_used:self.bytes_used + nbytes]
|
||||||
|
self.bytes_used += len(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def consume_exactly(self, nbytes):
|
||||||
|
if len(self.buffer) - self.bytes_used < nbytes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.consume_at_most(nbytes)
|
||||||
|
|
||||||
|
def commit(self):
|
||||||
|
# In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic
|
||||||
|
del self.buffer[:self.bytes_used]
|
||||||
|
self.bytes_used = 0
|
||||||
|
|
||||||
|
def rollback(self):
|
||||||
|
self.bytes_used = 0
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.buffer)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageDecoder(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.opcode = None
|
||||||
|
self.validator = None
|
||||||
|
self.decoder = None
|
||||||
|
|
||||||
|
def process_frame(self, frame):
|
||||||
|
assert not frame.opcode.iscontrol()
|
||||||
|
|
||||||
|
if self.opcode is None:
|
||||||
|
if frame.opcode is Opcode.CONTINUATION:
|
||||||
|
raise ParseFailed("unexpected CONTINUATION")
|
||||||
|
self.opcode = frame.opcode
|
||||||
|
elif frame.opcode is not Opcode.CONTINUATION:
|
||||||
|
raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode)
|
||||||
|
|
||||||
|
if frame.opcode is Opcode.TEXT:
|
||||||
|
self.validator = Utf8Validator()
|
||||||
|
self.decoder = getincrementaldecoder("utf-8")()
|
||||||
|
|
||||||
|
finished = frame.frame_finished and frame.message_finished
|
||||||
|
|
||||||
|
if self.decoder is not None:
|
||||||
|
data = self.decode_payload(frame.payload, finished)
|
||||||
|
else:
|
||||||
|
data = frame.payload
|
||||||
|
|
||||||
|
frame = Frame(self.opcode, data, frame.frame_finished, finished)
|
||||||
|
|
||||||
|
if finished:
|
||||||
|
self.opcode = None
|
||||||
|
self.decoder = None
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
def decode_payload(self, data, finished):
|
||||||
|
if self.validator is not None:
|
||||||
|
results = self.validator.validate(bytes(data))
|
||||||
|
if not results[0] or (finished and not results[1]):
|
||||||
|
raise ParseFailed(u'encountered invalid UTF-8 while processing'
|
||||||
|
' text message at payload octet index %d' %
|
||||||
|
results[3],
|
||||||
|
CloseReason.INVALID_FRAME_PAYLOAD_DATA)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.decoder.decode(data, finished)
|
||||||
|
except UnicodeDecodeError as exc:
|
||||||
|
raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameDecoder(object):
|
||||||
|
def __init__(self, client, extensions=None):
|
||||||
|
self.client = client
|
||||||
|
self.extensions = extensions or []
|
||||||
|
|
||||||
|
self.buffer = Buffer()
|
||||||
|
|
||||||
|
self.header = None
|
||||||
|
self.effective_opcode = None
|
||||||
|
self.masker = None
|
||||||
|
self.payload_required = 0
|
||||||
|
self.payload_consumed = 0
|
||||||
|
|
||||||
|
def receive_bytes(self, data):
|
||||||
|
self.buffer.feed(data)
|
||||||
|
|
||||||
|
def process_buffer(self):
|
||||||
|
if not self.header:
|
||||||
|
if not self.parse_header():
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(self.buffer) < self.payload_required:
|
||||||
|
return None
|
||||||
|
|
||||||
|
payload_remaining = self.header.payload_len - self.payload_consumed
|
||||||
|
payload = self.buffer.consume_at_most(payload_remaining)
|
||||||
|
if not payload and self.header.payload_len > 0:
|
||||||
|
return None
|
||||||
|
self.buffer.commit()
|
||||||
|
|
||||||
|
self.payload_consumed += len(payload)
|
||||||
|
finished = self.payload_consumed == self.header.payload_len
|
||||||
|
|
||||||
|
payload = self.masker.process(payload)
|
||||||
|
|
||||||
|
for extension in self.extensions:
|
||||||
|
payload = extension.frame_inbound_payload_data(self, payload)
|
||||||
|
if isinstance(payload, CloseReason):
|
||||||
|
raise ParseFailed("error in extension", payload)
|
||||||
|
|
||||||
|
if finished:
|
||||||
|
final = bytearray()
|
||||||
|
for extension in self.extensions:
|
||||||
|
result = extension.frame_inbound_complete(self,
|
||||||
|
self.header.fin)
|
||||||
|
if isinstance(result, CloseReason):
|
||||||
|
raise ParseFailed("error in extension", result)
|
||||||
|
if result is not None:
|
||||||
|
final += result
|
||||||
|
payload += final
|
||||||
|
|
||||||
|
frame = Frame(self.effective_opcode, payload, finished,
|
||||||
|
self.header.fin)
|
||||||
|
|
||||||
|
if finished:
|
||||||
|
self.header = None
|
||||||
|
self.effective_opcode = None
|
||||||
|
self.masker = None
|
||||||
|
else:
|
||||||
|
self.effective_opcode = Opcode.CONTINUATION
|
||||||
|
|
||||||
|
return frame
|
||||||
|
|
||||||
|
def parse_header(self):
|
||||||
|
data = self.buffer.consume_exactly(2)
|
||||||
|
if data is None:
|
||||||
|
self.buffer.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
fin = bool(data[0] & FIN_MASK)
|
||||||
|
rsv = RsvBits(bool(data[0] & RSV1_MASK),
|
||||||
|
bool(data[0] & RSV2_MASK),
|
||||||
|
bool(data[0] & RSV3_MASK))
|
||||||
|
opcode = data[0] & OPCODE_MASK
|
||||||
|
try:
|
||||||
|
opcode = Opcode(opcode)
|
||||||
|
except ValueError:
|
||||||
|
raise ParseFailed("Invalid opcode {:#x}".format(opcode))
|
||||||
|
|
||||||
|
if opcode.iscontrol() and not fin:
|
||||||
|
raise ParseFailed("Invalid attempt to fragment control frame")
|
||||||
|
|
||||||
|
has_mask = bool(data[1] & MASK_MASK)
|
||||||
|
payload_len = data[1] & PAYLOAD_LEN_MASK
|
||||||
|
payload_len = self.parse_extended_payload_length(opcode, payload_len)
|
||||||
|
if payload_len is None:
|
||||||
|
self.buffer.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.extension_processing(opcode, rsv, payload_len)
|
||||||
|
|
||||||
|
if has_mask and self.client:
|
||||||
|
raise ParseFailed("client received unexpected masked frame")
|
||||||
|
if not has_mask and not self.client:
|
||||||
|
raise ParseFailed("server received unexpected unmasked frame")
|
||||||
|
if has_mask:
|
||||||
|
masking_key = self.buffer.consume_exactly(4)
|
||||||
|
if masking_key is None:
|
||||||
|
self.buffer.rollback()
|
||||||
|
return False
|
||||||
|
self.masker = XorMaskerSimple(masking_key)
|
||||||
|
else:
|
||||||
|
self.masker = XorMaskerNull()
|
||||||
|
|
||||||
|
self.buffer.commit()
|
||||||
|
self.header = Header(fin, rsv, opcode, payload_len, None)
|
||||||
|
self.effective_opcode = self.header.opcode
|
||||||
|
if self.header.opcode.iscontrol():
|
||||||
|
self.payload_required = payload_len
|
||||||
|
else:
|
||||||
|
self.payload_required = 0
|
||||||
|
self.payload_consumed = 0
|
||||||
|
return True
|
||||||
|
|
||||||
|
def parse_extended_payload_length(self, opcode, payload_len):
|
||||||
|
if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL:
|
||||||
|
raise ParseFailed("Control frame with payload len > 125")
|
||||||
|
if payload_len == PAYLOAD_LENGTH_TWO_BYTE:
|
||||||
|
data = self.buffer.consume_exactly(2)
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
(payload_len,) = struct.unpack("!H", data)
|
||||||
|
if payload_len <= MAX_PAYLOAD_NORMAL:
|
||||||
|
raise ParseFailed(
|
||||||
|
"Payload length used 2 bytes when 1 would have sufficed")
|
||||||
|
elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE:
|
||||||
|
data = self.buffer.consume_exactly(8)
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
(payload_len,) = struct.unpack("!Q", data)
|
||||||
|
if payload_len <= MAX_PAYLOAD_TWO_BYTE:
|
||||||
|
raise ParseFailed(
|
||||||
|
"Payload length used 8 bytes when 2 would have sufficed")
|
||||||
|
if payload_len >> 63:
|
||||||
|
# I'm not sure why this is illegal, but that's what the RFC
|
||||||
|
# says, so...
|
||||||
|
raise ParseFailed("8-byte payload length with non-zero MSB")
|
||||||
|
|
||||||
|
return payload_len
|
||||||
|
|
||||||
|
def extension_processing(self, opcode, rsv, payload_len):
|
||||||
|
rsv_used = [False, False, False]
|
||||||
|
for extension in self.extensions:
|
||||||
|
result = extension.frame_inbound_header(self, opcode, rsv,
|
||||||
|
payload_len)
|
||||||
|
if isinstance(result, CloseReason):
|
||||||
|
raise ParseFailed("error in extension", result)
|
||||||
|
for bit, used in enumerate(result):
|
||||||
|
if used:
|
||||||
|
rsv_used[bit] = True
|
||||||
|
for expected, found in zip(rsv_used, rsv):
|
||||||
|
if found and not expected:
|
||||||
|
raise ParseFailed("Reserved bit set unexpectedly")
|
||||||
|
|
||||||
|
|
||||||
|
class FrameProtocol(object):
|
||||||
|
class State(Enum):
|
||||||
|
HEADER = 1
|
||||||
|
PAYLOAD = 2
|
||||||
|
FRAME_COMPLETE = 3
|
||||||
|
FAILED = 4
|
||||||
|
|
||||||
|
def __init__(self, client, extensions):
|
||||||
|
self.client = client
|
||||||
|
self.extensions = [ext for ext in extensions if ext.enabled()]
|
||||||
|
|
||||||
|
# Global state
|
||||||
|
self._frame_decoder = FrameDecoder(self.client, self.extensions)
|
||||||
|
self._message_decoder = MessageDecoder()
|
||||||
|
self._parse_more = self.parse_more_gen()
|
||||||
|
|
||||||
|
self._outbound_opcode = None
|
||||||
|
|
||||||
|
def _process_close(self, frame):
|
||||||
|
data = frame.payload
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
# "If this Close control frame contains no status code, _The
|
||||||
|
# WebSocket Connection Close Code_ is considered to be 1005"
|
||||||
|
data = (CloseReason.NO_STATUS_RCVD, "")
|
||||||
|
elif len(data) == 1:
|
||||||
|
raise ParseFailed("CLOSE with 1 byte payload")
|
||||||
|
else:
|
||||||
|
(code,) = struct.unpack("!H", data[:2])
|
||||||
|
if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON:
|
||||||
|
raise ParseFailed("CLOSE with invalid code")
|
||||||
|
try:
|
||||||
|
code = CloseReason(code)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
if code in LOCAL_ONLY_CLOSE_REASONS:
|
||||||
|
raise ParseFailed(
|
||||||
|
"remote CLOSE with local-only reason")
|
||||||
|
if not isinstance(code, CloseReason) and \
|
||||||
|
code <= MAX_PROTOCOL_CLOSE_REASON:
|
||||||
|
raise ParseFailed(
|
||||||
|
"CLOSE with unknown reserved code")
|
||||||
|
validator = Utf8Validator()
|
||||||
|
if validator is not None:
|
||||||
|
results = validator.validate(bytes(data[2:]))
|
||||||
|
if not (results[0] and results[1]):
|
||||||
|
raise ParseFailed(u'encountered invalid UTF-8 while'
|
||||||
|
' processing close message at payload'
|
||||||
|
' octet index %d' %
|
||||||
|
results[3],
|
||||||
|
CloseReason.INVALID_FRAME_PAYLOAD_DATA)
|
||||||
|
try:
|
||||||
|
reason = data[2:].decode("utf-8")
|
||||||
|
except UnicodeDecodeError as exc:
|
||||||
|
raise ParseFailed(
|
||||||
|
"Error decoding CLOSE reason: " + str(exc),
|
||||||
|
CloseReason.INVALID_FRAME_PAYLOAD_DATA)
|
||||||
|
data = (code, reason)
|
||||||
|
|
||||||
|
return Frame(frame.opcode, data, frame.frame_finished,
|
||||||
|
frame.message_finished)
|
||||||
|
|
||||||
|
def parse_more_gen(self):
|
||||||
|
# Consume as much as we can from self._buffer, yielding events, and
|
||||||
|
# then yield None when we need more data. Or raise ParseFailed.
|
||||||
|
|
||||||
|
# XX FIXME this should probably be refactored so that we never see
|
||||||
|
# disabled extensions in the first place...
|
||||||
|
self.extensions = [ext for ext in self.extensions if ext.enabled()]
|
||||||
|
closed = False
|
||||||
|
|
||||||
|
while not closed:
|
||||||
|
frame = self._frame_decoder.process_buffer()
|
||||||
|
|
||||||
|
if frame is not None:
|
||||||
|
if not frame.opcode.iscontrol():
|
||||||
|
frame = self._message_decoder.process_frame(frame)
|
||||||
|
elif frame.opcode == Opcode.CLOSE:
|
||||||
|
frame = self._process_close(frame)
|
||||||
|
closed = True
|
||||||
|
|
||||||
|
yield frame
|
||||||
|
|
||||||
|
def receive_bytes(self, data):
|
||||||
|
self._frame_decoder.receive_bytes(data)
|
||||||
|
|
||||||
|
def received_frames(self):
|
||||||
|
for event in self._parse_more:
|
||||||
|
if event is None:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
yield event
|
||||||
|
|
||||||
|
def close(self, code=None, reason=None):
|
||||||
|
payload = bytearray()
|
||||||
|
if code is None and reason is not None:
|
||||||
|
raise TypeError("cannot specify a reason without a code")
|
||||||
|
if code in LOCAL_ONLY_CLOSE_REASONS:
|
||||||
|
code = CloseReason.NORMAL_CLOSURE
|
||||||
|
if code is not None:
|
||||||
|
payload += bytearray(struct.pack('!H', code))
|
||||||
|
if reason is not None:
|
||||||
|
payload += _truncate_utf8(reason.encode('utf-8'),
|
||||||
|
MAX_PAYLOAD_NORMAL - 2)
|
||||||
|
|
||||||
|
return self._serialize_frame(Opcode.CLOSE, payload)
|
||||||
|
|
||||||
|
def ping(self, payload=b''):
|
||||||
|
return self._serialize_frame(Opcode.PING, payload)
|
||||||
|
|
||||||
|
def pong(self, payload=b''):
|
||||||
|
return self._serialize_frame(Opcode.PONG, payload)
|
||||||
|
|
||||||
|
def send_data(self, payload=b'', fin=True):
|
||||||
|
if isinstance(payload, (bytes, bytearray, memoryview)):
|
||||||
|
opcode = Opcode.BINARY
|
||||||
|
elif isinstance(payload, unicode):
|
||||||
|
opcode = Opcode.TEXT
|
||||||
|
payload = payload.encode('utf-8')
|
||||||
|
else:
|
||||||
|
raise ValueError('Must provide bytes or text')
|
||||||
|
|
||||||
|
if self._outbound_opcode is None:
|
||||||
|
self._outbound_opcode = opcode
|
||||||
|
elif self._outbound_opcode is not opcode:
|
||||||
|
raise TypeError('Data type mismatch inside message')
|
||||||
|
else:
|
||||||
|
opcode = Opcode.CONTINUATION
|
||||||
|
|
||||||
|
if fin:
|
||||||
|
self._outbound_opcode = None
|
||||||
|
|
||||||
|
return self._serialize_frame(opcode, payload, fin)
|
||||||
|
|
||||||
|
def _make_fin_rsv_opcode(self, fin, rsv, opcode):
|
||||||
|
fin = int(fin) << 7
|
||||||
|
rsv = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + \
|
||||||
|
(int(rsv.rsv3) << 4)
|
||||||
|
opcode = int(opcode)
|
||||||
|
|
||||||
|
return fin | rsv | opcode
|
||||||
|
|
||||||
|
def _serialize_frame(self, opcode, payload=b'', fin=True):
|
||||||
|
rsv = RsvBits(False, False, False)
|
||||||
|
for extension in reversed(self.extensions):
|
||||||
|
rsv, payload = extension.frame_outbound(self, opcode, rsv, payload,
|
||||||
|
fin)
|
||||||
|
|
||||||
|
fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode)
|
||||||
|
|
||||||
|
payload_length = len(payload)
|
||||||
|
quad_payload = False
|
||||||
|
if payload_length <= MAX_PAYLOAD_NORMAL:
|
||||||
|
first_payload = payload_length
|
||||||
|
second_payload = None
|
||||||
|
elif payload_length <= MAX_PAYLOAD_TWO_BYTE:
|
||||||
|
first_payload = PAYLOAD_LENGTH_TWO_BYTE
|
||||||
|
second_payload = payload_length
|
||||||
|
else:
|
||||||
|
first_payload = PAYLOAD_LENGTH_EIGHT_BYTE
|
||||||
|
second_payload = payload_length
|
||||||
|
quad_payload = True
|
||||||
|
|
||||||
|
if self.client:
|
||||||
|
first_payload |= 1 << 7
|
||||||
|
|
||||||
|
header = bytearray([fin_rsv_opcode, first_payload])
|
||||||
|
if second_payload is not None:
|
||||||
|
if opcode.iscontrol():
|
||||||
|
raise ValueError("payload too long for control frame")
|
||||||
|
if quad_payload:
|
||||||
|
header += bytearray(struct.pack('!Q', second_payload))
|
||||||
|
else:
|
||||||
|
header += bytearray(struct.pack('!H', second_payload))
|
||||||
|
|
||||||
|
if self.client:
|
||||||
|
# "The masking key is a 32-bit value chosen at random by the
|
||||||
|
# client. When preparing a masked frame, the client MUST pick a
|
||||||
|
# fresh masking key from the set of allowed 32-bit values. The
|
||||||
|
# masking key needs to be unpredictable; thus, the masking key
|
||||||
|
# MUST be derived from a strong source of entropy, and the masking
|
||||||
|
# key for a given frame MUST NOT make it simple for a server/proxy
|
||||||
|
# to predict the masking key for a subsequent frame. The
|
||||||
|
# unpredictability of the masking key is essential to prevent
|
||||||
|
# authors of malicious applications from selecting the bytes that
|
||||||
|
# appear on the wire."
|
||||||
|
# -- https://tools.ietf.org/html/rfc6455#section-5.3
|
||||||
|
masking_key = os.urandom(4)
|
||||||
|
masker = XorMaskerSimple(masking_key)
|
||||||
|
return header + masking_key + masker.process(payload)
|
||||||
|
|
||||||
|
return header + payload
|
@ -1,14 +1,18 @@
|
|||||||
import os
|
|
||||||
import socket
|
import socket
|
||||||
import struct
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
|
from mitmproxy.contrib.wsproto import events
|
||||||
|
from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection
|
||||||
|
from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate
|
||||||
|
from mitmproxy.contrib.wsproto.frame_protocol import Opcode
|
||||||
|
|
||||||
from mitmproxy import exceptions
|
from mitmproxy import exceptions
|
||||||
from mitmproxy import flow
|
from mitmproxy import flow
|
||||||
from mitmproxy.proxy.protocol import base
|
from mitmproxy.proxy.protocol import base
|
||||||
from mitmproxy.net import tcp
|
from mitmproxy.net import tcp
|
||||||
from mitmproxy.net import websockets
|
from mitmproxy.net import websockets
|
||||||
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
|
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
|
||||||
|
from mitmproxy.utils import strutils
|
||||||
|
|
||||||
|
|
||||||
class WebSocketLayer(base.Layer):
|
class WebSocketLayer(base.Layer):
|
||||||
@ -44,26 +48,56 @@ class WebSocketLayer(base.Layer):
|
|||||||
self.client_frame_buffer = []
|
self.client_frame_buffer = []
|
||||||
self.server_frame_buffer = []
|
self.server_frame_buffer = []
|
||||||
|
|
||||||
def _handle_frame(self, frame, source_conn, other_conn, is_server):
|
self.connections = {} # type: Dict[object, WSConnection]
|
||||||
if frame.header.opcode & 0x8 == 0:
|
|
||||||
return self._handle_data_frame(frame, source_conn, other_conn, is_server)
|
|
||||||
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
|
|
||||||
return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
|
|
||||||
elif frame.header.opcode == websockets.OPCODE.CLOSE:
|
|
||||||
return self._handle_close(frame, source_conn, other_conn, is_server)
|
|
||||||
else:
|
|
||||||
return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)
|
|
||||||
|
|
||||||
def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
|
extensions = []
|
||||||
|
if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
|
||||||
|
if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
|
||||||
|
extensions = [PerMessageDeflate()]
|
||||||
|
self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
|
||||||
|
extensions=extensions)
|
||||||
|
self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
|
||||||
|
host=handshake_flow.request.host,
|
||||||
|
resource=handshake_flow.request.path,
|
||||||
|
extensions=extensions)
|
||||||
|
if extensions:
|
||||||
|
for conn in self.connections.values():
|
||||||
|
conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions'])
|
||||||
|
|
||||||
|
data = self.connections[self.server_conn].bytes_to_send()
|
||||||
|
self.connections[self.client_conn].receive_bytes(data)
|
||||||
|
|
||||||
|
event = next(self.connections[self.client_conn].events())
|
||||||
|
assert isinstance(event, events.ConnectionRequested)
|
||||||
|
|
||||||
|
self.connections[self.client_conn].accept(event)
|
||||||
|
self.connections[self.server_conn].receive_bytes(self.connections[self.client_conn].bytes_to_send())
|
||||||
|
assert isinstance(next(self.connections[self.server_conn].events()), events.ConnectionEstablished)
|
||||||
|
|
||||||
|
def _handle_event(self, event, source_conn, other_conn, is_server):
|
||||||
|
if isinstance(event, events.DataReceived):
|
||||||
|
return self._handle_data_received(event, source_conn, other_conn, is_server)
|
||||||
|
elif isinstance(event, events.PingReceived):
|
||||||
|
return self._handle_ping_received(event, source_conn, other_conn, is_server)
|
||||||
|
elif isinstance(event, events.PongReceived):
|
||||||
|
return self._handle_pong_received(event, source_conn, other_conn, is_server)
|
||||||
|
elif isinstance(event, events.ConnectionClosed):
|
||||||
|
return self._handle_connection_closed(event, source_conn, other_conn, is_server)
|
||||||
|
|
||||||
|
# fail-safe for unhandled events
|
||||||
|
return True # pragma: no cover
|
||||||
|
|
||||||
|
def _handle_data_received(self, event, source_conn, other_conn, is_server):
|
||||||
fb = self.server_frame_buffer if is_server else self.client_frame_buffer
|
fb = self.server_frame_buffer if is_server else self.client_frame_buffer
|
||||||
fb.append(frame)
|
fb.append(event.data)
|
||||||
|
|
||||||
if frame.header.fin:
|
if event.message_finished:
|
||||||
payload = b''.join(f.payload for f in fb)
|
original_chunk_sizes = [len(f) for f in fb]
|
||||||
original_chunk_sizes = [len(f.payload) for f in fb]
|
message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY
|
||||||
message_type = fb[0].header.opcode
|
if message_type == Opcode.TEXT:
|
||||||
compressed_message = fb[0].header.rsv1
|
payload = ''.join(fb)
|
||||||
|
else:
|
||||||
|
payload = b''.join(fb)
|
||||||
fb.clear()
|
fb.clear()
|
||||||
|
|
||||||
websocket_message = WebSocketMessage(message_type, not is_server, payload)
|
websocket_message = WebSocketMessage(message_type, not is_server, payload)
|
||||||
@ -77,7 +111,7 @@ class WebSocketLayer(base.Layer):
|
|||||||
# message has the same length, we can reuse the same sizes
|
# message has the same length, we can reuse the same sizes
|
||||||
pos = 0
|
pos = 0
|
||||||
for s in original_chunk_sizes:
|
for s in original_chunk_sizes:
|
||||||
yield payload[pos:pos + s]
|
yield (payload[pos:pos + s], True if pos + s == length else False)
|
||||||
pos += s
|
pos += s
|
||||||
else:
|
else:
|
||||||
# just re-chunk everything into 4kB frames
|
# just re-chunk everything into 4kB frames
|
||||||
@ -85,95 +119,81 @@ class WebSocketLayer(base.Layer):
|
|||||||
chunk_size = 4092 if is_server else 4088
|
chunk_size = 4092 if is_server else 4088
|
||||||
chunks = range(0, len(payload), chunk_size)
|
chunks = range(0, len(payload), chunk_size)
|
||||||
for i in chunks:
|
for i in chunks:
|
||||||
yield payload[i:i + chunk_size]
|
yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
|
||||||
|
|
||||||
frms = [
|
for chunk, final in get_chunk(websocket_message.content):
|
||||||
websockets.Frame(
|
self.connections[other_conn].send_data(chunk, final)
|
||||||
payload=chunk,
|
other_conn.send(self.connections[other_conn].bytes_to_send())
|
||||||
opcode=frame.header.opcode,
|
|
||||||
mask=(False if is_server else 1),
|
|
||||||
masking_key=(b'' if is_server else os.urandom(4)))
|
|
||||||
for chunk in get_chunk(websocket_message.content)
|
|
||||||
]
|
|
||||||
|
|
||||||
if len(frms) > 0:
|
|
||||||
frms[-1].header.fin = True
|
|
||||||
else:
|
|
||||||
frms.append(websockets.Frame(
|
|
||||||
fin=True,
|
|
||||||
opcode=websockets.OPCODE.CONTINUE,
|
|
||||||
mask=(False if is_server else 1),
|
|
||||||
masking_key=(b'' if is_server else os.urandom(4))))
|
|
||||||
|
|
||||||
frms[0].header.opcode = message_type
|
|
||||||
frms[0].header.rsv1 = compressed_message
|
|
||||||
|
|
||||||
for frm in frms:
|
|
||||||
other_conn.send(bytes(frm))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
other_conn.send(bytes(frame))
|
self.connections[other_conn].send_data(event.data, event.message_finished)
|
||||||
|
other_conn.send(self.connections[other_conn].bytes_to_send())
|
||||||
|
|
||||||
elif self.flow.stream:
|
elif self.flow.stream:
|
||||||
other_conn.send(bytes(frame))
|
self.connections[other_conn].send_data(event.data, event.message_finished)
|
||||||
|
other_conn.send(self.connections[other_conn].bytes_to_send())
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
|
def _handle_ping_received(self, event, source_conn, other_conn, is_server):
|
||||||
# just forward the ping/pong to the other side
|
# PING is automatically answered with a PONG by wsproto
|
||||||
other_conn.send(bytes(frame))
|
self.connections[other_conn].ping()
|
||||||
|
other_conn.send(self.connections[other_conn].bytes_to_send())
|
||||||
|
source_conn.send(self.connections[source_conn].bytes_to_send())
|
||||||
|
self.log(
|
||||||
|
"Ping Received from {}".format("server" if is_server else "client"),
|
||||||
|
"info",
|
||||||
|
[strutils.bytes_to_escaped_str(bytes(event.payload))]
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_close(self, frame, source_conn, other_conn, is_server):
|
def _handle_pong_received(self, event, source_conn, other_conn, is_server):
|
||||||
|
self.log(
|
||||||
|
"Pong Received from {}".format("server" if is_server else "client"),
|
||||||
|
"info",
|
||||||
|
[strutils.bytes_to_escaped_str(bytes(event.payload))]
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
|
||||||
self.flow.close_sender = "server" if is_server else "client"
|
self.flow.close_sender = "server" if is_server else "client"
|
||||||
if len(frame.payload) >= 2:
|
self.flow.close_code = event.code
|
||||||
code, = struct.unpack('!H', frame.payload[:2])
|
self.flow.close_reason = event.reason
|
||||||
self.flow.close_code = code
|
|
||||||
self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
|
|
||||||
if len(frame.payload) > 2:
|
|
||||||
self.flow.close_reason = frame.payload[2:]
|
|
||||||
|
|
||||||
other_conn.send(bytes(frame))
|
self.connections[other_conn].close(event.code, event.reason)
|
||||||
|
other_conn.send(self.connections[other_conn].bytes_to_send())
|
||||||
|
source_conn.send(self.connections[source_conn].bytes_to_send())
|
||||||
|
|
||||||
# initiate close handshake
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
|
|
||||||
# unknown frame - just forward it
|
|
||||||
other_conn.send(bytes(frame))
|
|
||||||
|
|
||||||
sender = "server" if is_server else "client"
|
|
||||||
self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
|
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
|
||||||
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
|
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
|
||||||
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
|
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
|
||||||
self.channel.ask("websocket_start", self.flow)
|
self.channel.ask("websocket_start", self.flow)
|
||||||
|
|
||||||
client = self.client_conn.connection
|
conns = [c.connection for c in self.connections.keys()]
|
||||||
server = self.server_conn.connection
|
|
||||||
conns = [client, server]
|
|
||||||
close_received = False
|
close_received = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not self.channel.should_exit.is_set():
|
while not self.channel.should_exit.is_set():
|
||||||
r = tcp.ssl_read_select(conns, 0.1)
|
r = tcp.ssl_read_select(conns, 0.1)
|
||||||
for conn in r:
|
for conn in r:
|
||||||
source_conn = self.client_conn if conn == client else self.server_conn
|
source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
|
||||||
other_conn = self.server_conn if conn == client else self.client_conn
|
other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
|
||||||
is_server = (conn == self.server_conn.connection)
|
is_server = (source_conn == self.server_conn)
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(source_conn.rfile)
|
frame = websockets.Frame.from_file(source_conn.rfile)
|
||||||
|
self.connections[source_conn].receive_bytes(bytes(frame))
|
||||||
|
source_conn.send(self.connections[source_conn].bytes_to_send())
|
||||||
|
|
||||||
cont = self._handle_frame(frame, source_conn, other_conn, is_server)
|
if close_received:
|
||||||
if not cont:
|
return
|
||||||
if close_received:
|
|
||||||
return
|
for event in self.connections[source_conn].events():
|
||||||
else:
|
if not self._handle_event(event, source_conn, other_conn, is_server):
|
||||||
close_received = True
|
if not close_received:
|
||||||
|
close_received = True
|
||||||
except (socket.error, exceptions.TcpException, SSL.Error) as e:
|
except (socket.error, exceptions.TcpException, SSL.Error) as e:
|
||||||
s = 'server' if is_server else 'client'
|
s = 'server' if is_server else 'client'
|
||||||
self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e)))
|
self.flow.error = flow.Error("WebSocket connection closed unexpectedly by {}: {}".format(s, repr(e)))
|
||||||
|
@ -49,7 +49,7 @@ class UnsupportedLog:
|
|||||||
def websocket_message(self, f):
|
def websocket_message(self, f):
|
||||||
message = f.messages[-1]
|
message = f.messages[-1]
|
||||||
signals.add_log(f.message_info(message), "info")
|
signals.add_log(f.message_info(message), "info")
|
||||||
signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
|
signals.add_log(message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content), "debug")
|
||||||
|
|
||||||
def websocket_end(self, f):
|
def websocket_end(self, f):
|
||||||
signals.add_log("WebSocket connection closed by {}: {} {}, {}".format(
|
signals.add_log("WebSocket connection closed by {}: {} {}, {}".format(
|
||||||
|
@ -21,7 +21,13 @@ exclude_lines =
|
|||||||
|
|
||||||
[tool:full_coverage]
|
[tool:full_coverage]
|
||||||
exclude =
|
exclude =
|
||||||
mitmproxy/proxy/protocol/
|
mitmproxy/proxy/protocol/base.py
|
||||||
|
mitmproxy/proxy/protocol/http.py
|
||||||
|
mitmproxy/proxy/protocol/http1.py
|
||||||
|
mitmproxy/proxy/protocol/http2.py
|
||||||
|
mitmproxy/proxy/protocol/http_replay.py
|
||||||
|
mitmproxy/proxy/protocol/rawtcp.py
|
||||||
|
mitmproxy/proxy/protocol/tls.py
|
||||||
mitmproxy/proxy/root_context.py
|
mitmproxy/proxy/root_context.py
|
||||||
mitmproxy/proxy/server.py
|
mitmproxy/proxy/server.py
|
||||||
mitmproxy/tools/
|
mitmproxy/tools/
|
||||||
@ -64,7 +70,6 @@ exclude =
|
|||||||
mitmproxy/proxy/protocol/http_replay.py
|
mitmproxy/proxy/protocol/http_replay.py
|
||||||
mitmproxy/proxy/protocol/rawtcp.py
|
mitmproxy/proxy/protocol/rawtcp.py
|
||||||
mitmproxy/proxy/protocol/tls.py
|
mitmproxy/proxy/protocol/tls.py
|
||||||
mitmproxy/proxy/protocol/websocket.py
|
|
||||||
mitmproxy/proxy/root_context.py
|
mitmproxy/proxy/root_context.py
|
||||||
mitmproxy/proxy/server.py
|
mitmproxy/proxy/server.py
|
||||||
mitmproxy/stateobject.py
|
mitmproxy/stateobject.py
|
||||||
|
1
setup.py
1
setup.py
@ -65,6 +65,7 @@ setup(
|
|||||||
"certifi>=2015.11.20.1", # no semver here - this should always be on the last release!
|
"certifi>=2015.11.20.1", # no semver here - this should always be on the last release!
|
||||||
"click>=6.2, <7",
|
"click>=6.2, <7",
|
||||||
"cryptography>=2.0,<2.2",
|
"cryptography>=2.0,<2.2",
|
||||||
|
'h11>=0.7.0,<0.8',
|
||||||
"h2>=3.0, <4",
|
"h2>=3.0, <4",
|
||||||
"hyperframe>=5.0, <6",
|
"hyperframe>=5.0, <6",
|
||||||
"kaitaistruct>=0.7, <0.8",
|
"kaitaistruct>=0.7, <0.8",
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import os
|
import os
|
||||||
|
import struct
|
||||||
import tempfile
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
@ -33,6 +34,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase):
|
|||||||
connection='upgrade',
|
connection='upgrade',
|
||||||
upgrade='websocket',
|
upgrade='websocket',
|
||||||
sec_websocket_accept=b'',
|
sec_websocket_accept=b'',
|
||||||
|
sec_websocket_extensions='permessage-deflate' if "permessage-deflate" in request.headers.values() else ''
|
||||||
),
|
),
|
||||||
content=b'',
|
content=b'',
|
||||||
)
|
)
|
||||||
@ -80,7 +82,7 @@ class _WebSocketTestBase:
|
|||||||
if self.client:
|
if self.client:
|
||||||
self.client.close()
|
self.client.close()
|
||||||
|
|
||||||
def setup_connection(self):
|
def setup_connection(self, extension=False):
|
||||||
self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port))
|
self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port))
|
||||||
self.client.connect()
|
self.client.connect()
|
||||||
|
|
||||||
@ -115,6 +117,7 @@ class _WebSocketTestBase:
|
|||||||
upgrade="websocket",
|
upgrade="websocket",
|
||||||
sec_websocket_version="13",
|
sec_websocket_version="13",
|
||||||
sec_websocket_key="1234",
|
sec_websocket_key="1234",
|
||||||
|
sec_websocket_extensions="permessage-deflate" if extension else ""
|
||||||
),
|
),
|
||||||
content=b'')
|
content=b'')
|
||||||
self.client.wfile.write(http.http1.assemble_request(request))
|
self.client.wfile.write(http.http1.assemble_request(request))
|
||||||
@ -145,11 +148,11 @@ class TestSimple(_WebSocketTest):
|
|||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(rfile)
|
frame = websockets.Frame.from_file(rfile)
|
||||||
wfile.write(bytes(frame))
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(rfile)
|
frame = websockets.Frame.from_file(rfile)
|
||||||
wfile.write(bytes(frame))
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
@pytest.mark.parametrize('streaming', [True, False])
|
@pytest.mark.parametrize('streaming', [True, False])
|
||||||
@ -164,36 +167,59 @@ class TestSimple(_WebSocketTest):
|
|||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'server-foobar'
|
assert frame.payload == b'server-foobar'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'self.client-foobar'
|
assert frame.payload == b'self.client-foobar'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'\xde\xad\xbe\xef'
|
assert frame.payload == b'\xde\xad\xbe\xef'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
assert len(self.master.state.flows) == 2
|
assert len(self.master.state.flows) == 2
|
||||||
assert isinstance(self.master.state.flows[0], HTTPFlow)
|
assert isinstance(self.master.state.flows[0], HTTPFlow)
|
||||||
assert isinstance(self.master.state.flows[1], WebSocketFlow)
|
assert isinstance(self.master.state.flows[1], WebSocketFlow)
|
||||||
assert len(self.master.state.flows[1].messages) == 5
|
assert len(self.master.state.flows[1].messages) == 5
|
||||||
assert self.master.state.flows[1].messages[0].content == b'server-foobar'
|
assert self.master.state.flows[1].messages[0].content == 'server-foobar'
|
||||||
assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
|
assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
|
||||||
assert self.master.state.flows[1].messages[1].content == b'self.client-foobar'
|
assert self.master.state.flows[1].messages[1].content == 'self.client-foobar'
|
||||||
assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
|
assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
|
||||||
assert self.master.state.flows[1].messages[2].content == b'self.client-foobar'
|
assert self.master.state.flows[1].messages[2].content == 'self.client-foobar'
|
||||||
assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
|
assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
|
||||||
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
|
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
|
||||||
assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
|
assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
|
||||||
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
|
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
|
||||||
assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY
|
assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY
|
||||||
|
|
||||||
|
def test_change_payload(self):
|
||||||
|
class Addon:
|
||||||
|
def websocket_message(self, f):
|
||||||
|
f.messages[-1].content = "foo"
|
||||||
|
|
||||||
|
self.master.addons.add(Addon())
|
||||||
|
self.setup_connection()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
|
assert frame.payload == b'foo'
|
||||||
|
|
||||||
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
|
assert frame.payload == b'foo'
|
||||||
|
|
||||||
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
|
assert frame.payload == b'foo'
|
||||||
|
|
||||||
|
|
||||||
class TestSimpleTLS(_WebSocketTest):
|
class TestSimpleTLS(_WebSocketTest):
|
||||||
ssl = True
|
ssl = True
|
||||||
@ -204,7 +230,7 @@ class TestSimpleTLS(_WebSocketTest):
|
|||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(rfile)
|
frame = websockets.Frame.from_file(rfile)
|
||||||
wfile.write(bytes(frame))
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
def test_simple_tls(self):
|
def test_simple_tls(self):
|
||||||
@ -213,13 +239,13 @@ class TestSimpleTLS(_WebSocketTest):
|
|||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'server-foobar'
|
assert frame.payload == b'server-foobar'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'self.client-foobar'
|
assert frame.payload == b'self.client-foobar'
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
|
||||||
@ -234,22 +260,24 @@ class TestPing(_WebSocketTest):
|
|||||||
assert frame.header.opcode == websockets.OPCODE.PONG
|
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||||
assert frame.payload == b'foobar'
|
assert frame.payload == b'foobar'
|
||||||
|
|
||||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done')))
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
|
wfile.flush()
|
||||||
|
websockets.Frame.from_file(rfile)
|
||||||
|
|
||||||
def test_ping(self):
|
def test_ping(self):
|
||||||
self.setup_connection()
|
self.setup_connection()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.header.opcode == websockets.OPCODE.PING
|
websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.payload == b'foobar'
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
assert frame.header.opcode == websockets.OPCODE.PING
|
||||||
|
assert frame.payload == b'' # We don't send payload to other end
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
assert self.master.has_log("Pong Received from server", "info")
|
||||||
assert frame.header.opcode == websockets.OPCODE.TEXT
|
|
||||||
assert frame.payload == b'pong-received'
|
|
||||||
|
|
||||||
|
|
||||||
class TestPong(_WebSocketTest):
|
class TestPong(_WebSocketTest):
|
||||||
@ -258,20 +286,29 @@ class TestPong(_WebSocketTest):
|
|||||||
def handle_websockets(cls, rfile, wfile):
|
def handle_websockets(cls, rfile, wfile):
|
||||||
frame = websockets.Frame.from_file(rfile)
|
frame = websockets.Frame.from_file(rfile)
|
||||||
assert frame.header.opcode == websockets.OPCODE.PING
|
assert frame.header.opcode == websockets.OPCODE.PING
|
||||||
assert frame.payload == b'foobar'
|
assert frame.payload == b''
|
||||||
|
|
||||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
|
wfile.flush()
|
||||||
|
websockets.Frame.from_file(rfile)
|
||||||
|
|
||||||
def test_pong(self):
|
def test_pong(self):
|
||||||
self.setup_connection()
|
self.setup_connection()
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
|
websockets.Frame.from_file(self.client.rfile)
|
||||||
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
assert frame.header.opcode == websockets.OPCODE.PONG
|
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||||
assert frame.payload == b'foobar'
|
assert frame.payload == b'foobar'
|
||||||
|
assert self.master.has_log("Pong Received from server", "info")
|
||||||
|
|
||||||
|
|
||||||
class TestClose(_WebSocketTest):
|
class TestClose(_WebSocketTest):
|
||||||
@ -279,7 +316,7 @@ class TestClose(_WebSocketTest):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def handle_websockets(cls, rfile, wfile):
|
def handle_websockets(cls, rfile, wfile):
|
||||||
frame = websockets.Frame.from_file(rfile)
|
frame = websockets.Frame.from_file(rfile)
|
||||||
wfile.write(bytes(frame))
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
|
||||||
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
wfile.flush()
|
wfile.flush()
|
||||||
|
|
||||||
@ -289,7 +326,7 @@ class TestClose(_WebSocketTest):
|
|||||||
def test_close(self):
|
def test_close(self):
|
||||||
self.setup_connection()
|
self.setup_connection()
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
websockets.Frame.from_file(self.client.rfile)
|
websockets.Frame.from_file(self.client.rfile)
|
||||||
@ -299,7 +336,7 @@ class TestClose(_WebSocketTest):
|
|||||||
def test_close_payload_1(self):
|
def test_close_payload_1(self):
|
||||||
self.setup_connection()
|
self.setup_connection()
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
websockets.Frame.from_file(self.client.rfile)
|
websockets.Frame.from_file(self.client.rfile)
|
||||||
@ -309,7 +346,7 @@ class TestClose(_WebSocketTest):
|
|||||||
def test_close_payload_2(self):
|
def test_close_payload_2(self):
|
||||||
self.setup_connection()
|
self.setup_connection()
|
||||||
|
|
||||||
self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
||||||
self.client.wfile.flush()
|
self.client.wfile.flush()
|
||||||
|
|
||||||
websockets.Frame.from_file(self.client.rfile)
|
websockets.Frame.from_file(self.client.rfile)
|
||||||
@ -329,8 +366,9 @@ class TestInvalidFrame(_WebSocketTest):
|
|||||||
|
|
||||||
# with pytest.raises(exceptions.TcpDisconnect):
|
# with pytest.raises(exceptions.TcpDisconnect):
|
||||||
frame = websockets.Frame.from_file(self.client.rfile)
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
assert frame.header.opcode == 15
|
code, = struct.unpack('!H', frame.payload[:2])
|
||||||
assert frame.payload == b'foobar'
|
assert code == 1002
|
||||||
|
assert frame.payload[2:].startswith(b'Invalid opcode')
|
||||||
|
|
||||||
|
|
||||||
class TestStreaming(_WebSocketTest):
|
class TestStreaming(_WebSocketTest):
|
||||||
@ -360,3 +398,51 @@ class TestStreaming(_WebSocketTest):
|
|||||||
|
|
||||||
assert frame
|
assert frame
|
||||||
assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received
|
assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtension(_WebSocketTest):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_websockets(cls, rfile, wfile):
|
||||||
|
wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00')
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(rfile)
|
||||||
|
assert frame.header.rsv1
|
||||||
|
wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00')
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(rfile)
|
||||||
|
assert frame.header.rsv1
|
||||||
|
wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00')
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
def test_extension(self):
|
||||||
|
self.setup_connection(True)
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
|
assert frame.header.rsv1
|
||||||
|
|
||||||
|
self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v')
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
|
assert frame.header.rsv1
|
||||||
|
|
||||||
|
self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c')
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(self.client.rfile)
|
||||||
|
assert frame.header.rsv1
|
||||||
|
|
||||||
|
assert len(self.master.state.flows[1].messages) == 5
|
||||||
|
assert self.master.state.flows[1].messages[0].content == 'server-foobar'
|
||||||
|
assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
|
||||||
|
assert self.master.state.flows[1].messages[1].content == 'client-foobar'
|
||||||
|
assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
|
||||||
|
assert self.master.state.flows[1].messages[2].content == 'client-foobar'
|
||||||
|
assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
|
||||||
|
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
|
||||||
|
assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
|
||||||
|
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
|
||||||
|
assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY
|
||||||
|
Loading…
Reference in New Issue
Block a user