mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
Merge pull request #1488 from mitmproxy/websockets
add WebSockets support
This commit is contained in:
commit
55d938b880
@ -126,6 +126,18 @@ HTTP Events
|
|||||||
:param HTTPFlow flow: The flow containing the error.
|
:param HTTPFlow flow: The flow containing the error.
|
||||||
It is guaranteed to have non-None ``error`` attribute.
|
It is guaranteed to have non-None ``error`` attribute.
|
||||||
|
|
||||||
|
WebSockets Events
|
||||||
|
^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. py:function:: websockets_handshake(context, flow)
|
||||||
|
|
||||||
|
Called when a client wants to establish a WebSockets connection.
|
||||||
|
The WebSockets-specific headers can be manipulated to manipulate the handshake.
|
||||||
|
The ``flow`` object is guaranteed to have a non-None ``request`` attribute.
|
||||||
|
|
||||||
|
:param HTTPFlow flow: The flow containing the request which has been received.
|
||||||
|
The object is guaranteed to have a non-None ``request`` attribute.
|
||||||
|
|
||||||
TCP Events
|
TCP Events
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
@ -28,6 +28,8 @@ Events = frozenset([
|
|||||||
"response",
|
"response",
|
||||||
"responseheaders",
|
"responseheaders",
|
||||||
|
|
||||||
|
"websockets_handshake",
|
||||||
|
|
||||||
"next_layer",
|
"next_layer",
|
||||||
|
|
||||||
"error",
|
"error",
|
||||||
|
@ -334,6 +334,10 @@ class FlowMaster(controller.Master):
|
|||||||
self.client_playback.clear(f)
|
self.client_playback.clear(f)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
|
def websockets_handshake(self, f):
|
||||||
|
return f
|
||||||
|
|
||||||
def handle_intercept(self, f):
|
def handle_intercept(self, f):
|
||||||
self.state.update_flow(f)
|
self.state.update_flow(f)
|
||||||
|
|
||||||
|
@ -29,8 +29,10 @@ from __future__ import absolute_import, print_function, division
|
|||||||
|
|
||||||
from .base import Layer, ServerConnectionMixin
|
from .base import Layer, ServerConnectionMixin
|
||||||
from .http import UpstreamConnectLayer
|
from .http import UpstreamConnectLayer
|
||||||
|
from .http import HttpLayer
|
||||||
from .http1 import Http1Layer
|
from .http1 import Http1Layer
|
||||||
from .http2 import Http2Layer
|
from .http2 import Http2Layer
|
||||||
|
from .websockets import WebSocketsLayer
|
||||||
from .rawtcp import RawTCPLayer
|
from .rawtcp import RawTCPLayer
|
||||||
from .tls import TlsClientHello
|
from .tls import TlsClientHello
|
||||||
from .tls import TlsLayer
|
from .tls import TlsLayer
|
||||||
@ -40,7 +42,9 @@ __all__ = [
|
|||||||
"Layer", "ServerConnectionMixin",
|
"Layer", "ServerConnectionMixin",
|
||||||
"TlsLayer", "is_tls_record_magic", "TlsClientHello",
|
"TlsLayer", "is_tls_record_magic", "TlsClientHello",
|
||||||
"UpstreamConnectLayer",
|
"UpstreamConnectLayer",
|
||||||
|
"HttpLayer",
|
||||||
"Http1Layer",
|
"Http1Layer",
|
||||||
"Http2Layer",
|
"Http2Layer",
|
||||||
|
"WebSocketsLayer",
|
||||||
"RawTCPLayer",
|
"RawTCPLayer",
|
||||||
]
|
]
|
||||||
|
@ -7,12 +7,14 @@ import traceback
|
|||||||
import h2.exceptions
|
import h2.exceptions
|
||||||
import six
|
import six
|
||||||
|
|
||||||
import netlib.exceptions
|
|
||||||
from mitmproxy import exceptions
|
from mitmproxy import exceptions
|
||||||
from mitmproxy import models
|
from mitmproxy import models
|
||||||
from mitmproxy.protocol import base
|
from mitmproxy.protocol import base
|
||||||
|
|
||||||
|
import netlib.exceptions
|
||||||
from netlib import http
|
from netlib import http
|
||||||
from netlib import tcp
|
from netlib import tcp
|
||||||
|
from netlib import websockets
|
||||||
|
|
||||||
|
|
||||||
class _HttpTransmissionLayer(base.Layer):
|
class _HttpTransmissionLayer(base.Layer):
|
||||||
@ -189,6 +191,11 @@ class HttpLayer(base.Layer):
|
|||||||
self.process_request_hook(flow)
|
self.process_request_hook(flow)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
|
||||||
|
# we only support RFC6455 with WebSockets version 13
|
||||||
|
# allow inline scripts to manupulate the client handshake
|
||||||
|
self.channel.ask("websockets_handshake", flow)
|
||||||
|
|
||||||
if not flow.response:
|
if not flow.response:
|
||||||
self.establish_server_connection(
|
self.establish_server_connection(
|
||||||
flow.request.host,
|
flow.request.host,
|
||||||
@ -212,7 +219,7 @@ class HttpLayer(base.Layer):
|
|||||||
# It may be useful to pass additional args (such as the upgrade header)
|
# It may be useful to pass additional args (such as the upgrade header)
|
||||||
# to next_layer in the future
|
# to next_layer in the future
|
||||||
if flow.response.status_code == 101:
|
if flow.response.status_code == 101:
|
||||||
layer = self.ctx.next_layer(self)
|
layer = self.ctx.next_layer(self, flow)
|
||||||
layer()
|
layer()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
108
mitmproxy/protocol/websockets.py
Normal file
108
mitmproxy/protocol/websockets.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
from __future__ import absolute_import, print_function, division
|
||||||
|
|
||||||
|
import socket
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from OpenSSL import SSL
|
||||||
|
|
||||||
|
from mitmproxy import exceptions
|
||||||
|
from mitmproxy.protocol import base
|
||||||
|
|
||||||
|
import netlib.exceptions
|
||||||
|
from netlib import tcp
|
||||||
|
from netlib import websockets
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketsLayer(base.Layer):
|
||||||
|
"""
|
||||||
|
WebSockets layer to intercept, modify, and forward WebSockets connections
|
||||||
|
|
||||||
|
Only version 13 is supported (as specified in RFC6455)
|
||||||
|
Only HTTP/1.1-initiated connections are supported.
|
||||||
|
|
||||||
|
The client starts by sending an Upgrade-request.
|
||||||
|
In order to determine the handshake and negotiate the correct protocol
|
||||||
|
and extensions, the Upgrade-request is forwarded to the server.
|
||||||
|
The response from the server is then parsed and negotiated settings are extracted.
|
||||||
|
Finally the handshake is completed by forwarding the server-response to the client.
|
||||||
|
After that, only WebSockets frames are exchanged.
|
||||||
|
|
||||||
|
PING/PONG frames pass through and must be answered by the other endpoint.
|
||||||
|
|
||||||
|
CLOSE frames are forwarded before this WebSocketsLayer terminates.
|
||||||
|
|
||||||
|
This layer is transparent to any negotiated extensions.
|
||||||
|
This layer is transparent to any negotiated subprotocols.
|
||||||
|
Only raw frames are forwarded to the other endpoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ctx, flow):
|
||||||
|
super(WebSocketsLayer, self).__init__(ctx)
|
||||||
|
self._flow = flow
|
||||||
|
|
||||||
|
self.client_key = websockets.get_client_key(self._flow.request.headers)
|
||||||
|
self.client_protocol = websockets.get_protocol(self._flow.request.headers)
|
||||||
|
self.client_extensions = websockets.get_extensions(self._flow.request.headers)
|
||||||
|
|
||||||
|
self.server_accept = websockets.get_server_accept(self._flow.response.headers)
|
||||||
|
self.server_protocol = websockets.get_protocol(self._flow.response.headers)
|
||||||
|
self.server_extensions = websockets.get_extensions(self._flow.response.headers)
|
||||||
|
|
||||||
|
def _handle_frame(self, frame, source_conn, other_conn, is_server):
|
||||||
|
self.log(
|
||||||
|
"WebSockets Frame received from {}".format("server" if is_server else "client"),
|
||||||
|
"debug",
|
||||||
|
[repr(frame)]
|
||||||
|
)
|
||||||
|
|
||||||
|
if frame.header.opcode & 0x8 == 0:
|
||||||
|
# forward the data frame to the other side
|
||||||
|
other_conn.send(bytes(frame))
|
||||||
|
self.log("WebSockets frame received by {}: {}".format(is_server, frame), "debug")
|
||||||
|
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
|
||||||
|
# just forward the ping/pong to the other side
|
||||||
|
other_conn.send(bytes(frame))
|
||||||
|
elif frame.header.opcode == websockets.OPCODE.CLOSE:
|
||||||
|
other_conn.send(bytes(frame))
|
||||||
|
|
||||||
|
code = '(status code missing)'
|
||||||
|
msg = None
|
||||||
|
reason = '(message missing)'
|
||||||
|
if len(frame.payload) >= 2:
|
||||||
|
code, = struct.unpack('!H', frame.payload[:2])
|
||||||
|
msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
|
||||||
|
if len(frame.payload) > 2:
|
||||||
|
reason = frame.payload[2:]
|
||||||
|
self.log("WebSockets connection closed: {} {}, {}".format(code, msg, reason), "info")
|
||||||
|
|
||||||
|
# close the connection
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# unknown frame - just forward it
|
||||||
|
other_conn.send(bytes(frame))
|
||||||
|
|
||||||
|
# continue the connection
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
client = self.client_conn.connection
|
||||||
|
server = self.server_conn.connection
|
||||||
|
conns = [client, server]
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not self.channel.should_exit.is_set():
|
||||||
|
r = tcp.ssl_read_select(conns, 1)
|
||||||
|
for conn in r:
|
||||||
|
source_conn = self.client_conn if conn == client else self.server_conn
|
||||||
|
other_conn = self.server_conn if conn == client else self.client_conn
|
||||||
|
is_server = (conn == self.server_conn.connection)
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(source_conn.rfile)
|
||||||
|
|
||||||
|
if not self._handle_frame(frame, source_conn, other_conn, is_server):
|
||||||
|
return
|
||||||
|
except (socket.error, netlib.exceptions.TcpException, SSL.Error) as e:
|
||||||
|
self.log("WebSockets connection closed unexpectedly by {}: {}".format(
|
||||||
|
"server" if is_server else "client", repr(e)), "info")
|
||||||
|
except Exception as e: # pragma: no cover
|
||||||
|
raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e)))
|
@ -4,6 +4,7 @@ import sys
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from netlib import websockets
|
||||||
import netlib.exceptions
|
import netlib.exceptions
|
||||||
from mitmproxy import exceptions
|
from mitmproxy import exceptions
|
||||||
from mitmproxy import protocol
|
from mitmproxy import protocol
|
||||||
@ -32,7 +33,7 @@ class RootContext(object):
|
|||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def next_layer(self, top_layer):
|
def next_layer(self, top_layer, flow=None):
|
||||||
"""
|
"""
|
||||||
This function determines the next layer in the protocol stack.
|
This function determines the next layer in the protocol stack.
|
||||||
|
|
||||||
@ -42,10 +43,22 @@ class RootContext(object):
|
|||||||
Returns:
|
Returns:
|
||||||
The next layer
|
The next layer
|
||||||
"""
|
"""
|
||||||
layer = self._next_layer(top_layer)
|
layer = self._next_layer(top_layer, flow)
|
||||||
return self.channel.ask("next_layer", layer)
|
return self.channel.ask("next_layer", layer)
|
||||||
|
|
||||||
def _next_layer(self, top_layer):
|
def _next_layer(self, top_layer, flow):
|
||||||
|
if flow is not None:
|
||||||
|
# We already have a flow, try to derive the next information from it
|
||||||
|
|
||||||
|
# Check for WebSockets handshake
|
||||||
|
is_websockets = (
|
||||||
|
flow and
|
||||||
|
websockets.check_handshake(flow.request.headers) and
|
||||||
|
websockets.check_handshake(flow.response.headers)
|
||||||
|
)
|
||||||
|
if isinstance(top_layer, protocol.HttpLayer) and is_websockets:
|
||||||
|
return protocol.WebSocketsLayer(top_layer, flow)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
d = top_layer.client_conn.rfile.peek(3)
|
d = top_layer.client_conn.rfile.peek(3)
|
||||||
except netlib.exceptions.TcpException as e:
|
except netlib.exceptions.TcpException as e:
|
||||||
|
@ -1,11 +1,37 @@
|
|||||||
from __future__ import absolute_import, print_function, division
|
from __future__ import absolute_import, print_function, division
|
||||||
from .frame import FrameHeader, Frame, OPCODE
|
|
||||||
from .protocol import Masker, WebsocketsProtocol
|
from .frame import FrameHeader
|
||||||
|
from .frame import Frame
|
||||||
|
from .frame import OPCODE
|
||||||
|
from .frame import CLOSE_REASON
|
||||||
|
from .masker import Masker
|
||||||
|
from .utils import MAGIC
|
||||||
|
from .utils import VERSION
|
||||||
|
from .utils import client_handshake_headers
|
||||||
|
from .utils import server_handshake_headers
|
||||||
|
from .utils import check_handshake
|
||||||
|
from .utils import check_client_version
|
||||||
|
from .utils import create_server_nonce
|
||||||
|
from .utils import get_extensions
|
||||||
|
from .utils import get_protocol
|
||||||
|
from .utils import get_client_key
|
||||||
|
from .utils import get_server_accept
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FrameHeader",
|
"FrameHeader",
|
||||||
"Frame",
|
"Frame",
|
||||||
"Masker",
|
|
||||||
"WebsocketsProtocol",
|
|
||||||
"OPCODE",
|
"OPCODE",
|
||||||
|
"CLOSE_REASON",
|
||||||
|
"Masker",
|
||||||
|
"MAGIC",
|
||||||
|
"VERSION",
|
||||||
|
"client_handshake_headers",
|
||||||
|
"server_handshake_headers",
|
||||||
|
"check_handshake",
|
||||||
|
"check_client_version",
|
||||||
|
"create_server_nonce",
|
||||||
|
"get_extensions",
|
||||||
|
"get_protocol",
|
||||||
|
"get_client_key",
|
||||||
|
"get_server_accept",
|
||||||
]
|
]
|
||||||
|
@ -2,7 +2,6 @@ from __future__ import absolute_import
|
|||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import io
|
import io
|
||||||
import warnings
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
@ -10,7 +9,7 @@ from netlib import tcp
|
|||||||
from netlib import strutils
|
from netlib import strutils
|
||||||
from netlib import utils
|
from netlib import utils
|
||||||
from netlib import human
|
from netlib import human
|
||||||
from netlib.websockets import protocol
|
from .masker import Masker
|
||||||
|
|
||||||
|
|
||||||
MAX_16_BIT_INT = (1 << 16)
|
MAX_16_BIT_INT = (1 << 16)
|
||||||
@ -18,6 +17,7 @@ MAX_64_BIT_INT = (1 << 64)
|
|||||||
|
|
||||||
DEFAULT = object()
|
DEFAULT = object()
|
||||||
|
|
||||||
|
# RFC 6455, Section 5.2 - Base Framing Protocol
|
||||||
OPCODE = utils.BiDi(
|
OPCODE = utils.BiDi(
|
||||||
CONTINUE=0x00,
|
CONTINUE=0x00,
|
||||||
TEXT=0x01,
|
TEXT=0x01,
|
||||||
@ -27,6 +27,23 @@ OPCODE = utils.BiDi(
|
|||||||
PONG=0x0a
|
PONG=0x0a
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# RFC 6455, Section 7.4.1 - Defined Status Codes
|
||||||
|
CLOSE_REASON = utils.BiDi(
|
||||||
|
NORMAL_CLOSURE=1000,
|
||||||
|
GOING_AWAY=1001,
|
||||||
|
PROTOCOL_ERROR=1002,
|
||||||
|
UNSUPPORTED_DATA=1003,
|
||||||
|
RESERVED=1004,
|
||||||
|
RESERVED_NO_STATUS=1005,
|
||||||
|
RESERVED_ABNORMAL_CLOSURE=1006,
|
||||||
|
INVALID_PAYLOAD_DATA=1007,
|
||||||
|
POLICY_VIOLATION=1008,
|
||||||
|
MESSAGE_TOO_BIG=1009,
|
||||||
|
MANDATORY_EXTENSION=1010,
|
||||||
|
INTERNAL_ERROR=1011,
|
||||||
|
RESERVED_TLS_HANDHSAKE_FAILED=1015,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FrameHeader(object):
|
class FrameHeader(object):
|
||||||
|
|
||||||
@ -103,10 +120,6 @@ class FrameHeader(object):
|
|||||||
vals.append(" %s" % human.pretty_size(self.payload_length))
|
vals.append(" %s" % human.pretty_size(self.payload_length))
|
||||||
return "".join(vals)
|
return "".join(vals)
|
||||||
|
|
||||||
def human_readable(self):
|
|
||||||
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
|
|
||||||
return repr(self)
|
|
||||||
|
|
||||||
def __bytes__(self):
|
def __bytes__(self):
|
||||||
first_byte = utils.setbit(0, 7, self.fin)
|
first_byte = utils.setbit(0, 7, self.fin)
|
||||||
first_byte = utils.setbit(first_byte, 6, self.rsv1)
|
first_byte = utils.setbit(first_byte, 6, self.rsv1)
|
||||||
@ -128,6 +141,9 @@ class FrameHeader(object):
|
|||||||
# '!Q' = pack as 64 bit unsigned long long
|
# '!Q' = pack as 64 bit unsigned long long
|
||||||
# add 8 bytes extended payload length
|
# add 8 bytes extended payload length
|
||||||
b += struct.pack('!Q', self.payload_length)
|
b += struct.pack('!Q', self.payload_length)
|
||||||
|
else:
|
||||||
|
raise ValueError("Payload length exceeds 64bit integer")
|
||||||
|
|
||||||
if self.masking_key:
|
if self.masking_key:
|
||||||
b += self.masking_key
|
b += self.masking_key
|
||||||
return b
|
return b
|
||||||
@ -135,10 +151,6 @@ class FrameHeader(object):
|
|||||||
if six.PY2:
|
if six.PY2:
|
||||||
__str__ = __bytes__
|
__str__ = __bytes__
|
||||||
|
|
||||||
def to_bytes(self):
|
|
||||||
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
|
|
||||||
return bytes(self)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(cls, fp):
|
def from_file(cls, fp):
|
||||||
"""
|
"""
|
||||||
@ -151,19 +163,17 @@ class FrameHeader(object):
|
|||||||
rsv1 = utils.getbit(first_byte, 6)
|
rsv1 = utils.getbit(first_byte, 6)
|
||||||
rsv2 = utils.getbit(first_byte, 5)
|
rsv2 = utils.getbit(first_byte, 5)
|
||||||
rsv3 = utils.getbit(first_byte, 4)
|
rsv3 = utils.getbit(first_byte, 4)
|
||||||
# grab right-most 4 bits
|
opcode = first_byte & 0xF
|
||||||
opcode = first_byte & 15
|
|
||||||
mask_bit = utils.getbit(second_byte, 7)
|
mask_bit = utils.getbit(second_byte, 7)
|
||||||
# grab the next 7 bits
|
length_code = second_byte & 0x7F
|
||||||
length_code = second_byte & 127
|
|
||||||
|
|
||||||
# payload_lengthy > 125 indicates you need to read more bytes
|
# payload_length > 125 indicates you need to read more bytes
|
||||||
# to get the actual payload length
|
# to get the actual payload length
|
||||||
if length_code <= 125:
|
if length_code <= 125:
|
||||||
payload_length = length_code
|
payload_length = length_code
|
||||||
elif length_code == 126:
|
elif length_code == 126:
|
||||||
payload_length, = struct.unpack("!H", fp.safe_read(2))
|
payload_length, = struct.unpack("!H", fp.safe_read(2))
|
||||||
elif length_code == 127:
|
else: # length_code == 127:
|
||||||
payload_length, = struct.unpack("!Q", fp.safe_read(8))
|
payload_length, = struct.unpack("!Q", fp.safe_read(8))
|
||||||
|
|
||||||
# masking key only present if mask bit set
|
# masking key only present if mask bit set
|
||||||
@ -191,11 +201,10 @@ class FrameHeader(object):
|
|||||||
|
|
||||||
|
|
||||||
class Frame(object):
|
class Frame(object):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Represents one websockets frame.
|
Represents a single WebSockets frame.
|
||||||
Constructor takes human readable forms of the frame components
|
Constructor takes human readable forms of the frame components.
|
||||||
from_bytes() is also avaliable.
|
from_bytes() reads from a file-like object to create a new Frame.
|
||||||
|
|
||||||
WebSockets Frame as defined in RFC6455
|
WebSockets Frame as defined in RFC6455
|
||||||
|
|
||||||
@ -223,27 +232,6 @@ class Frame(object):
|
|||||||
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
|
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
|
||||||
self.header = FrameHeader(**kwargs)
|
self.header = FrameHeader(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def default(cls, message, from_client=False):
|
|
||||||
"""
|
|
||||||
Construct a basic websocket frame from some default values.
|
|
||||||
Creates a non-fragmented text frame.
|
|
||||||
"""
|
|
||||||
if from_client:
|
|
||||||
mask_bit = 1
|
|
||||||
masking_key = os.urandom(4)
|
|
||||||
else:
|
|
||||||
mask_bit = 0
|
|
||||||
masking_key = None
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
message,
|
|
||||||
fin=1, # final frame
|
|
||||||
opcode=OPCODE.TEXT, # text
|
|
||||||
mask=mask_bit,
|
|
||||||
masking_key=masking_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, bytestring):
|
def from_bytes(cls, bytestring):
|
||||||
"""
|
"""
|
||||||
@ -258,17 +246,13 @@ class Frame(object):
|
|||||||
ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload)
|
ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def human_readable(self):
|
|
||||||
warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
|
|
||||||
return repr(self)
|
|
||||||
|
|
||||||
def __bytes__(self):
|
def __bytes__(self):
|
||||||
"""
|
"""
|
||||||
Serialize the frame to wire format. Returns a string.
|
Serialize the frame to wire format. Returns a string.
|
||||||
"""
|
"""
|
||||||
b = bytes(self.header)
|
b = bytes(self.header)
|
||||||
if self.header.masking_key:
|
if self.header.masking_key:
|
||||||
b += protocol.Masker(self.header.masking_key)(self.payload)
|
b += Masker(self.header.masking_key)(self.payload)
|
||||||
else:
|
else:
|
||||||
b += self.payload
|
b += self.payload
|
||||||
return b
|
return b
|
||||||
@ -276,15 +260,6 @@ class Frame(object):
|
|||||||
if six.PY2:
|
if six.PY2:
|
||||||
__str__ = __bytes__
|
__str__ = __bytes__
|
||||||
|
|
||||||
def to_bytes(self):
|
|
||||||
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
|
|
||||||
return bytes(self)
|
|
||||||
|
|
||||||
def to_file(self, writer):
|
|
||||||
warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
|
|
||||||
writer.write(bytes(self))
|
|
||||||
writer.flush()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(cls, fp):
|
def from_file(cls, fp):
|
||||||
"""
|
"""
|
||||||
@ -297,20 +272,11 @@ class Frame(object):
|
|||||||
payload = fp.safe_read(header.payload_length)
|
payload = fp.safe_read(header.payload_length)
|
||||||
|
|
||||||
if header.mask == 1 and header.masking_key:
|
if header.mask == 1 and header.masking_key:
|
||||||
payload = protocol.Masker(header.masking_key)(payload)
|
payload = Masker(header.masking_key)(payload)
|
||||||
|
|
||||||
return cls(
|
frame = cls(payload)
|
||||||
payload,
|
frame.header = header
|
||||||
fin=header.fin,
|
return frame
|
||||||
opcode=header.opcode,
|
|
||||||
mask=header.mask,
|
|
||||||
payload_length=header.payload_length,
|
|
||||||
masking_key=header.masking_key,
|
|
||||||
rsv1=header.rsv1,
|
|
||||||
rsv2=header.rsv2,
|
|
||||||
rsv3=header.rsv3,
|
|
||||||
length_code=header.length_code
|
|
||||||
)
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, Frame):
|
if isinstance(other, Frame):
|
||||||
|
33
netlib/websockets/masker.py
Normal file
33
netlib/websockets/masker.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
|
class Masker(object):
|
||||||
|
"""
|
||||||
|
Data sent from the server must be masked to prevent malicious clients
|
||||||
|
from sending data over the wire in predictable patterns.
|
||||||
|
|
||||||
|
Servers do not have to mask data they send to the client.
|
||||||
|
https://tools.ietf.org/html/rfc6455#section-5.3
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, key):
|
||||||
|
self.key = key
|
||||||
|
self.offset = 0
|
||||||
|
|
||||||
|
def mask(self, offset, data):
|
||||||
|
result = bytearray(data)
|
||||||
|
for i in range(len(data)):
|
||||||
|
if six.PY2:
|
||||||
|
result[i] ^= ord(self.key[offset % 4])
|
||||||
|
else:
|
||||||
|
result[i] ^= self.key[offset % 4]
|
||||||
|
offset += 1
|
||||||
|
result = bytes(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
ret = self.mask(self.offset, data)
|
||||||
|
self.offset += len(ret)
|
||||||
|
return ret
|
@ -1,112 +0,0 @@
|
|||||||
"""
|
|
||||||
Colleciton of utility functions that implement small portions of the RFC6455
|
|
||||||
WebSockets Protocol Useful for building WebSocket clients and servers.
|
|
||||||
|
|
||||||
Emphassis is on readabilty, simplicity and modularity, not performance or
|
|
||||||
completeness
|
|
||||||
|
|
||||||
This is a work in progress and does not yet contain all the utilites need to
|
|
||||||
create fully complient client/servers #
|
|
||||||
Spec: https://tools.ietf.org/html/rfc6455
|
|
||||||
|
|
||||||
The magic sha that websocket servers must know to prove they understand
|
|
||||||
RFC6455
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
|
|
||||||
import six
|
|
||||||
|
|
||||||
from netlib import http, strutils
|
|
||||||
|
|
||||||
websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
|
||||||
VERSION = "13"
|
|
||||||
|
|
||||||
|
|
||||||
class Masker(object):
|
|
||||||
|
|
||||||
"""
|
|
||||||
Data sent from the server must be masked to prevent malicious clients
|
|
||||||
from sending data over the wire in predictable patterns
|
|
||||||
|
|
||||||
Servers do not have to mask data they send to the client.
|
|
||||||
https://tools.ietf.org/html/rfc6455#section-5.3
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, key):
|
|
||||||
self.key = key
|
|
||||||
self.offset = 0
|
|
||||||
|
|
||||||
def mask(self, offset, data):
|
|
||||||
result = bytearray(data)
|
|
||||||
if six.PY2:
|
|
||||||
for i in range(len(data)):
|
|
||||||
result[i] ^= ord(self.key[offset % 4])
|
|
||||||
offset += 1
|
|
||||||
result = str(result)
|
|
||||||
else:
|
|
||||||
|
|
||||||
for i in range(len(data)):
|
|
||||||
result[i] ^= self.key[offset % 4]
|
|
||||||
offset += 1
|
|
||||||
result = bytes(result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def __call__(self, data):
|
|
||||||
ret = self.mask(self.offset, data)
|
|
||||||
self.offset += len(ret)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
class WebsocketsProtocol(object):
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def client_handshake_headers(self, key=None, version=VERSION):
|
|
||||||
"""
|
|
||||||
Create the headers for a valid HTTP upgrade request. If Key is not
|
|
||||||
specified, it is generated, and can be found in sec-websocket-key in
|
|
||||||
the returned header set.
|
|
||||||
|
|
||||||
Returns an instance of http.Headers
|
|
||||||
"""
|
|
||||||
if not key:
|
|
||||||
key = base64.b64encode(os.urandom(16)).decode('ascii')
|
|
||||||
return http.Headers(
|
|
||||||
sec_websocket_key=key,
|
|
||||||
sec_websocket_version=version,
|
|
||||||
connection="Upgrade",
|
|
||||||
upgrade="websocket",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def server_handshake_headers(self, key):
|
|
||||||
"""
|
|
||||||
The server response is a valid HTTP 101 response.
|
|
||||||
"""
|
|
||||||
return http.Headers(
|
|
||||||
sec_websocket_accept=self.create_server_nonce(key),
|
|
||||||
connection="Upgrade",
|
|
||||||
upgrade="websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def check_client_handshake(self, headers):
|
|
||||||
if headers.get("upgrade") != "websocket":
|
|
||||||
return
|
|
||||||
return headers.get("sec-websocket-key")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def check_server_handshake(self, headers):
|
|
||||||
if headers.get("upgrade") != "websocket":
|
|
||||||
return
|
|
||||||
return headers.get("sec-websocket-accept")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_server_nonce(self, client_nonce):
|
|
||||||
return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + websockets_magic).digest())
|
|
90
netlib/websockets/utils.py
Normal file
90
netlib/websockets/utils.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
"""
|
||||||
|
Collection of WebSockets Protocol utility functions (RFC6455)
|
||||||
|
Spec: https://tools.ietf.org/html/rfc6455
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
from netlib import http, strutils
|
||||||
|
|
||||||
|
MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||||
|
VERSION = "13"
|
||||||
|
|
||||||
|
|
||||||
|
def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
|
||||||
|
"""
|
||||||
|
Create the headers for a valid HTTP upgrade request. If Key is not
|
||||||
|
specified, it is generated, and can be found in sec-websocket-key in
|
||||||
|
the returned header set.
|
||||||
|
|
||||||
|
Returns an instance of http.Headers
|
||||||
|
"""
|
||||||
|
if version is None:
|
||||||
|
version = VERSION
|
||||||
|
if key is None:
|
||||||
|
key = base64.b64encode(os.urandom(16)).decode('ascii')
|
||||||
|
h = http.Headers(
|
||||||
|
connection="upgrade",
|
||||||
|
upgrade="websocket",
|
||||||
|
sec_websocket_version=version,
|
||||||
|
sec_websocket_key=key,
|
||||||
|
)
|
||||||
|
if protocol is not None:
|
||||||
|
h['sec-websocket-protocol'] = protocol
|
||||||
|
if extensions is not None:
|
||||||
|
h['sec-websocket-extensions'] = extensions
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def server_handshake_headers(client_key, protocol=None, extensions=None):
|
||||||
|
"""
|
||||||
|
The server response is a valid HTTP 101 response.
|
||||||
|
|
||||||
|
Returns an instance of http.Headers
|
||||||
|
"""
|
||||||
|
h = http.Headers(
|
||||||
|
connection="upgrade",
|
||||||
|
upgrade="websocket",
|
||||||
|
sec_websocket_accept=create_server_nonce(client_key),
|
||||||
|
)
|
||||||
|
if protocol is not None:
|
||||||
|
h['sec-websocket-protocol'] = protocol
|
||||||
|
if extensions is not None:
|
||||||
|
h['sec-websocket-extensions'] = extensions
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def check_handshake(headers):
|
||||||
|
return (
|
||||||
|
"upgrade" in headers.get("connection", "").lower() and
|
||||||
|
headers.get("upgrade", "").lower() == "websocket" and
|
||||||
|
(headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_server_nonce(client_nonce):
|
||||||
|
return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest())
|
||||||
|
|
||||||
|
|
||||||
|
def check_client_version(headers):
|
||||||
|
return headers.get("sec-websocket-version", "") == VERSION
|
||||||
|
|
||||||
|
|
||||||
|
def get_extensions(headers):
|
||||||
|
return headers.get("sec-websocket-extensions", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_protocol(headers):
|
||||||
|
return headers.get("sec-websocket-protocol", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_key(headers):
|
||||||
|
return headers.get("sec-websocket-key", None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_server_accept(headers):
|
||||||
|
return headers.get("sec-websocket-accept", None)
|
@ -198,7 +198,7 @@ class Response(_HTTPMessage):
|
|||||||
1,
|
1,
|
||||||
StatusCode(101)
|
StatusCode(101)
|
||||||
)
|
)
|
||||||
headers = netlib.websockets.WebsocketsProtocol.server_handshake_headers(
|
headers = netlib.websockets.server_handshake_headers(
|
||||||
settings.websocket_key
|
settings.websocket_key
|
||||||
)
|
)
|
||||||
for i in headers.fields:
|
for i in headers.fields:
|
||||||
@ -310,7 +310,7 @@ class Request(_HTTPMessage):
|
|||||||
1,
|
1,
|
||||||
Method("get")
|
Method("get")
|
||||||
)
|
)
|
||||||
for i in netlib.websockets.WebsocketsProtocol.client_handshake_headers().fields:
|
for i in netlib.websockets.client_handshake_headers().fields:
|
||||||
if not get_header(i[0], self.headers):
|
if not get_header(i[0], self.headers):
|
||||||
tokens.append(
|
tokens.append(
|
||||||
Header(
|
Header(
|
||||||
|
@ -139,7 +139,7 @@ class WebsocketFrameReader(basethread.BaseThread):
|
|||||||
except exceptions.TcpDisconnect:
|
except exceptions.TcpDisconnect:
|
||||||
return
|
return
|
||||||
self.frames_queue.put(frm)
|
self.frames_queue.put(frm)
|
||||||
log("<< %s" % frm.header.human_readable())
|
log("<< %s" % repr(frm.header))
|
||||||
if self.ws_read_limit is not None:
|
if self.ws_read_limit is not None:
|
||||||
self.ws_read_limit -= 1
|
self.ws_read_limit -= 1
|
||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
|
@ -173,12 +173,13 @@ class PathodHandler(tcp.BaseHandler):
|
|||||||
retlog["cipher"] = self.get_current_cipher()
|
retlog["cipher"] = self.get_current_cipher()
|
||||||
|
|
||||||
m = utils.MemBool()
|
m = utils.MemBool()
|
||||||
websocket_key = websockets.WebsocketsProtocol.check_client_handshake(headers)
|
|
||||||
self.settings.websocket_key = websocket_key
|
valid_websockets_handshake = websockets.check_handshake(headers)
|
||||||
|
self.settings.websocket_key = websockets.get_client_key(headers)
|
||||||
|
|
||||||
# If this is a websocket initiation, we respond with a proper
|
# If this is a websocket initiation, we respond with a proper
|
||||||
# server response, unless over-ridden.
|
# server response, unless over-ridden.
|
||||||
if websocket_key:
|
if valid_websockets_handshake:
|
||||||
anchor_gen = language.parse_pathod("ws")
|
anchor_gen = language.parse_pathod("ws")
|
||||||
else:
|
else:
|
||||||
anchor_gen = None
|
anchor_gen = None
|
||||||
@ -225,7 +226,7 @@ class PathodHandler(tcp.BaseHandler):
|
|||||||
spec,
|
spec,
|
||||||
lg
|
lg
|
||||||
)
|
)
|
||||||
if nexthandler and websocket_key:
|
if nexthandler and valid_websockets_handshake:
|
||||||
self.protocol = protocols.websockets.WebsocketsProtocol(self)
|
self.protocol = protocols.websockets.WebsocketsProtocol(self)
|
||||||
return self.protocol.handle_websocket, retlog
|
return self.protocol.handle_websocket, retlog
|
||||||
else:
|
else:
|
||||||
|
@ -20,7 +20,7 @@ class WebsocketsProtocol:
|
|||||||
lg("Error reading websocket frame: %s" % e)
|
lg("Error reading websocket frame: %s" % e)
|
||||||
return None, None
|
return None, None
|
||||||
ended = time.time()
|
ended = time.time()
|
||||||
lg(frm.human_readable())
|
lg(repr(frm))
|
||||||
retlog = dict(
|
retlog = dict(
|
||||||
type="inbound",
|
type="inbound",
|
||||||
protocol="websockets",
|
protocol="websockets",
|
||||||
|
0
test/mitmproxy/protocol/__init__.py
Normal file
0
test/mitmproxy/protocol/__init__.py
Normal file
@ -1,7 +1,9 @@
|
|||||||
|
from __future__ import (absolute_import, print_function, division)
|
||||||
|
|
||||||
from netlib.http import http1
|
from netlib.http import http1
|
||||||
from netlib.tcp import TCPClient
|
from netlib.tcp import TCPClient
|
||||||
from netlib.tutils import treq
|
from netlib.tutils import treq
|
||||||
from . import tutils, tservers
|
from .. import tutils, tservers
|
||||||
|
|
||||||
|
|
||||||
class TestHTTPFlow(object):
|
class TestHTTPFlow(object):
|
@ -13,11 +13,11 @@ from mitmproxy import options
|
|||||||
from mitmproxy.proxy.config import ProxyConfig
|
from mitmproxy.proxy.config import ProxyConfig
|
||||||
|
|
||||||
import netlib
|
import netlib
|
||||||
from ..netlib import tservers as netlib_tservers
|
from ...netlib import tservers as netlib_tservers
|
||||||
from netlib.exceptions import HttpException
|
from netlib.exceptions import HttpException
|
||||||
from netlib.http.http2 import framereader
|
from netlib.http.http2 import framereader
|
||||||
|
|
||||||
from . import tservers
|
from .. import tservers
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING)
|
logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING)
|
299
test/mitmproxy/protocol/test_websockets.py
Normal file
299
test/mitmproxy/protocol/test_websockets.py
Normal file
@ -0,0 +1,299 @@
|
|||||||
|
from __future__ import absolute_import, print_function, division
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from mitmproxy import options
|
||||||
|
from mitmproxy.proxy.config import ProxyConfig
|
||||||
|
|
||||||
|
import netlib
|
||||||
|
from netlib import http
|
||||||
|
from ...netlib import tservers as netlib_tservers
|
||||||
|
from .. import tservers
|
||||||
|
|
||||||
|
from netlib import websockets
|
||||||
|
|
||||||
|
|
||||||
|
class _WebSocketsServerBase(netlib_tservers.ServerTestBase):
|
||||||
|
|
||||||
|
class handler(netlib.tcp.BaseHandler):
|
||||||
|
|
||||||
|
def handle(self):
|
||||||
|
try:
|
||||||
|
request = http.http1.read_request(self.rfile)
|
||||||
|
assert websockets.check_handshake(request.headers)
|
||||||
|
|
||||||
|
response = http.Response(
|
||||||
|
"HTTP/1.1",
|
||||||
|
101,
|
||||||
|
reason=http.status_codes.RESPONSES.get(101),
|
||||||
|
headers=http.Headers(
|
||||||
|
connection='upgrade',
|
||||||
|
upgrade='websocket',
|
||||||
|
sec_websocket_accept=b'',
|
||||||
|
),
|
||||||
|
content=b'',
|
||||||
|
)
|
||||||
|
self.wfile.write(http.http1.assemble_response(response))
|
||||||
|
self.wfile.flush()
|
||||||
|
|
||||||
|
self.server.handle_websockets(self.rfile, self.wfile)
|
||||||
|
except:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
class _WebSocketsTestBase(object):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
opts = cls.get_options()
|
||||||
|
cls.config = ProxyConfig(opts)
|
||||||
|
|
||||||
|
tmaster = tservers.TestMaster(opts, cls.config)
|
||||||
|
tmaster.start_app(options.APP_HOST, options.APP_PORT)
|
||||||
|
cls.proxy = tservers.ProxyThread(tmaster)
|
||||||
|
cls.proxy.start()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
cls.proxy.shutdown()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_options(cls):
|
||||||
|
opts = options.Options(
|
||||||
|
listen_port=0,
|
||||||
|
no_upstream_cert=False,
|
||||||
|
ssl_insecure=True
|
||||||
|
)
|
||||||
|
opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
|
||||||
|
return opts
|
||||||
|
|
||||||
|
@property
|
||||||
|
def master(self):
|
||||||
|
return self.proxy.tmaster
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
self.master.clear_log()
|
||||||
|
self.master.state.clear()
|
||||||
|
self.server.server.handle_websockets = self.handle_websockets
|
||||||
|
|
||||||
|
def _setup_connection(self):
|
||||||
|
client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port))
|
||||||
|
client.connect()
|
||||||
|
|
||||||
|
request = http.Request(
|
||||||
|
"authority",
|
||||||
|
"CONNECT",
|
||||||
|
"",
|
||||||
|
"localhost",
|
||||||
|
self.server.server.address.port,
|
||||||
|
"",
|
||||||
|
"HTTP/1.1",
|
||||||
|
content=b'')
|
||||||
|
client.wfile.write(http.http1.assemble_request(request))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
response = http.http1.read_response(client.rfile, request)
|
||||||
|
|
||||||
|
if self.ssl:
|
||||||
|
client.convert_to_ssl()
|
||||||
|
assert client.ssl_established
|
||||||
|
|
||||||
|
request = http.Request(
|
||||||
|
"relative",
|
||||||
|
"GET",
|
||||||
|
"http",
|
||||||
|
"localhost",
|
||||||
|
self.server.server.address.port,
|
||||||
|
"/ws",
|
||||||
|
"HTTP/1.1",
|
||||||
|
headers=http.Headers(
|
||||||
|
connection="upgrade",
|
||||||
|
upgrade="websocket",
|
||||||
|
sec_websocket_version="13",
|
||||||
|
sec_websocket_key="1234",
|
||||||
|
),
|
||||||
|
content=b'')
|
||||||
|
client.wfile.write(http.http1.assemble_request(request))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
response = http.http1.read_response(client.rfile, request)
|
||||||
|
assert websockets.check_handshake(response.headers)
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setup_class(cls):
|
||||||
|
_WebSocketsTestBase.setup_class()
|
||||||
|
_WebSocketsServerBase.setup_class(ssl=cls.ssl)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def teardown_class(cls):
|
||||||
|
_WebSocketsTestBase.teardown_class()
|
||||||
|
_WebSocketsServerBase.teardown_class()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimple(_WebSocketsTest):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_websockets(cls, rfile, wfile):
|
||||||
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(rfile)
|
||||||
|
wfile.write(bytes(frame))
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
client = self._setup_connection()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(client.rfile)
|
||||||
|
assert frame.payload == b'server-foobar'
|
||||||
|
|
||||||
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(client.rfile)
|
||||||
|
assert frame.payload == b'client-foobar'
|
||||||
|
|
||||||
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimpleTLS(_WebSocketsTest):
|
||||||
|
ssl = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_websockets(cls, rfile, wfile):
|
||||||
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(rfile)
|
||||||
|
wfile.write(bytes(frame))
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
def test_simple_tls(self):
|
||||||
|
client = self._setup_connection()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(client.rfile)
|
||||||
|
assert frame.payload == b'server-foobar'
|
||||||
|
|
||||||
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(client.rfile)
|
||||||
|
assert frame.payload == b'client-foobar'
|
||||||
|
|
||||||
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPing(_WebSocketsTest):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_websockets(cls, rfile, wfile):
|
||||||
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(rfile)
|
||||||
|
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||||
|
assert frame.payload == b'foobar'
|
||||||
|
|
||||||
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
def test_ping(self):
|
||||||
|
client = self._setup_connection()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(client.rfile)
|
||||||
|
assert frame.header.opcode == websockets.OPCODE.PING
|
||||||
|
assert frame.payload == b'foobar'
|
||||||
|
|
||||||
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(client.rfile)
|
||||||
|
assert frame.header.opcode == websockets.OPCODE.TEXT
|
||||||
|
assert frame.payload == b'pong-received'
|
||||||
|
|
||||||
|
|
||||||
|
class TestPong(_WebSocketsTest):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_websockets(cls, rfile, wfile):
|
||||||
|
frame = websockets.Frame.from_file(rfile)
|
||||||
|
assert frame.header.opcode == websockets.OPCODE.PING
|
||||||
|
assert frame.payload == b'foobar'
|
||||||
|
|
||||||
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
def test_pong(self):
|
||||||
|
client = self._setup_connection()
|
||||||
|
|
||||||
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
frame = websockets.Frame.from_file(client.rfile)
|
||||||
|
assert frame.header.opcode == websockets.OPCODE.PONG
|
||||||
|
assert frame.payload == b'foobar'
|
||||||
|
|
||||||
|
|
||||||
|
class TestClose(_WebSocketsTest):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_websockets(cls, rfile, wfile):
|
||||||
|
frame = websockets.Frame.from_file(rfile)
|
||||||
|
wfile.write(bytes(frame))
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||||
|
websockets.Frame.from_file(rfile)
|
||||||
|
|
||||||
|
def test_close(self):
|
||||||
|
client = self._setup_connection()
|
||||||
|
|
||||||
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||||
|
websockets.Frame.from_file(client.rfile)
|
||||||
|
|
||||||
|
def test_close_payload_1(self):
|
||||||
|
client = self._setup_connection()
|
||||||
|
|
||||||
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||||
|
websockets.Frame.from_file(client.rfile)
|
||||||
|
|
||||||
|
def test_close_payload_2(self):
|
||||||
|
client = self._setup_connection()
|
||||||
|
|
||||||
|
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
|
||||||
|
client.wfile.flush()
|
||||||
|
|
||||||
|
with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||||
|
websockets.Frame.from_file(client.rfile)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvalidFrame(_WebSocketsTest):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_websockets(cls, rfile, wfile):
|
||||||
|
wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar')))
|
||||||
|
wfile.flush()
|
||||||
|
|
||||||
|
def test_invalid_frame(self):
|
||||||
|
client = self._setup_connection()
|
||||||
|
|
||||||
|
# with pytest.raises(netlib.exceptions.TcpDisconnect):
|
||||||
|
frame = websockets.Frame.from_file(client.rfile)
|
||||||
|
assert frame.header.opcode == 15
|
||||||
|
assert frame.payload == b'foobar'
|
164
test/netlib/websockets/test_frame.py
Normal file
164
test/netlib/websockets/test_frame.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
import os
|
||||||
|
import codecs
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from netlib import websockets
|
||||||
|
from netlib import tutils
|
||||||
|
|
||||||
|
|
||||||
|
class TestFrameHeader(object):
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
(0, '0100'),
|
||||||
|
(125, '017D'),
|
||||||
|
(126, '017E007E'),
|
||||||
|
(127, '017E007F'),
|
||||||
|
(142, '017E008E'),
|
||||||
|
(65534, '017EFFFE'),
|
||||||
|
(65535, '017EFFFF'),
|
||||||
|
(65536, '017F0000000000010000'),
|
||||||
|
(8589934591, '017F00000001FFFFFFFF'),
|
||||||
|
(2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
|
||||||
|
])
|
||||||
|
def test_serialization_length(self, input, expected):
|
||||||
|
h = websockets.FrameHeader(
|
||||||
|
opcode=websockets.OPCODE.TEXT,
|
||||||
|
payload_length=input,
|
||||||
|
)
|
||||||
|
assert bytes(h) == codecs.decode(expected, 'hex')
|
||||||
|
|
||||||
|
def test_serialization_too_large(self):
|
||||||
|
h = websockets.FrameHeader(
|
||||||
|
payload_length=2 ** 64 + 1,
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
bytes(h)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
('0100', 0),
|
||||||
|
('017D', 125),
|
||||||
|
('017E007E', 126),
|
||||||
|
('017E007F', 127),
|
||||||
|
('017E008E', 142),
|
||||||
|
('017EFFFE', 65534),
|
||||||
|
('017EFFFF', 65535),
|
||||||
|
('017F0000000000010000', 65536),
|
||||||
|
('017F00000001FFFFFFFF', 8589934591),
|
||||||
|
('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
|
||||||
|
])
|
||||||
|
def test_deserialization_length(self, input, expected):
|
||||||
|
h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
|
||||||
|
assert h.payload_length == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
('0100', (False, None)),
|
||||||
|
('018000000000', (True, '00000000')),
|
||||||
|
('018012345678', (True, '12345678')),
|
||||||
|
])
|
||||||
|
def test_deserialization_masking(self, input, expected):
|
||||||
|
h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
|
||||||
|
assert h.mask == expected[0]
|
||||||
|
if h.mask:
|
||||||
|
assert h.masking_key == codecs.decode(expected[1], 'hex')
|
||||||
|
|
||||||
|
def test_equality(self):
|
||||||
|
h = websockets.FrameHeader(mask=True, masking_key=b'1234')
|
||||||
|
h2 = websockets.FrameHeader(mask=True, masking_key=b'1234')
|
||||||
|
assert h == h2
|
||||||
|
|
||||||
|
h = websockets.FrameHeader(fin=True)
|
||||||
|
h2 = websockets.FrameHeader(fin=False)
|
||||||
|
assert h != h2
|
||||||
|
|
||||||
|
assert h != 'foobar'
|
||||||
|
|
||||||
|
def test_roundtrip(self):
|
||||||
|
def round(*args, **kwargs):
|
||||||
|
h = websockets.FrameHeader(*args, **kwargs)
|
||||||
|
h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h)))
|
||||||
|
assert h == h2
|
||||||
|
|
||||||
|
round()
|
||||||
|
round(fin=True)
|
||||||
|
round(rsv1=True)
|
||||||
|
round(rsv2=True)
|
||||||
|
round(rsv3=True)
|
||||||
|
round(payload_length=1)
|
||||||
|
round(payload_length=100)
|
||||||
|
round(payload_length=1000)
|
||||||
|
round(payload_length=10000)
|
||||||
|
round(opcode=websockets.OPCODE.PING)
|
||||||
|
round(masking_key=b"test")
|
||||||
|
|
||||||
|
def test_human_readable(self):
|
||||||
|
f = websockets.FrameHeader(
|
||||||
|
masking_key=b"test",
|
||||||
|
fin=True,
|
||||||
|
payload_length=10
|
||||||
|
)
|
||||||
|
assert repr(f)
|
||||||
|
|
||||||
|
f = websockets.FrameHeader()
|
||||||
|
assert repr(f)
|
||||||
|
|
||||||
|
def test_funky(self):
|
||||||
|
f = websockets.FrameHeader(masking_key=b"test", mask=False)
|
||||||
|
raw = bytes(f)
|
||||||
|
f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
|
||||||
|
assert not f2.mask
|
||||||
|
|
||||||
|
def test_violations(self):
|
||||||
|
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
|
||||||
|
tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
|
||||||
|
|
||||||
|
def test_automask(self):
|
||||||
|
f = websockets.FrameHeader(mask=True)
|
||||||
|
assert f.masking_key
|
||||||
|
|
||||||
|
f = websockets.FrameHeader(masking_key=b"foob")
|
||||||
|
assert f.mask
|
||||||
|
|
||||||
|
f = websockets.FrameHeader(masking_key=b"foob", mask=0)
|
||||||
|
assert not f.mask
|
||||||
|
assert f.masking_key
|
||||||
|
|
||||||
|
|
||||||
|
class TestFrame(object):
|
||||||
|
def test_equality(self):
|
||||||
|
f = websockets.Frame(payload=b'1234')
|
||||||
|
f2 = websockets.Frame(payload=b'1234')
|
||||||
|
assert f == f2
|
||||||
|
|
||||||
|
assert f != b'1234'
|
||||||
|
|
||||||
|
def test_roundtrip(self):
|
||||||
|
def round(*args, **kwargs):
|
||||||
|
f = websockets.Frame(*args, **kwargs)
|
||||||
|
raw = bytes(f)
|
||||||
|
f2 = websockets.Frame.from_file(tutils.treader(raw))
|
||||||
|
assert f == f2
|
||||||
|
round(b"test")
|
||||||
|
round(b"test", fin=1)
|
||||||
|
round(b"test", rsv1=1)
|
||||||
|
round(b"test", opcode=websockets.OPCODE.PING)
|
||||||
|
round(b"test", masking_key=b"test")
|
||||||
|
|
||||||
|
def test_human_readable(self):
|
||||||
|
f = websockets.Frame()
|
||||||
|
assert repr(f)
|
||||||
|
|
||||||
|
f = websockets.Frame(b"foobar")
|
||||||
|
assert "foobar" in repr(f)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("masked", [True, False])
|
||||||
|
@pytest.mark.parametrize("length", [100, 50000, 150000])
|
||||||
|
def test_serialization_bijection(self, masked, length):
|
||||||
|
frame = websockets.Frame(
|
||||||
|
os.urandom(length),
|
||||||
|
fin=True,
|
||||||
|
opcode=websockets.OPCODE.TEXT,
|
||||||
|
mask=int(masked),
|
||||||
|
masking_key=(os.urandom(4) if masked else None)
|
||||||
|
)
|
||||||
|
serialized = bytes(frame)
|
||||||
|
assert frame == websockets.Frame.from_bytes(serialized)
|
23
test/netlib/websockets/test_masker.py
Normal file
23
test/netlib/websockets/test_masker.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import codecs
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from netlib import websockets
|
||||||
|
|
||||||
|
|
||||||
|
class TestMasker(object):
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
([b"a"], '00'),
|
||||||
|
([b"four"], '070d1616'),
|
||||||
|
([b"fourf"], '070d161607'),
|
||||||
|
([b"fourfive"], '070d1616070b1501'),
|
||||||
|
([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'),
|
||||||
|
([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'), # noqa
|
||||||
|
])
|
||||||
|
def test_masker(self, input, expected):
|
||||||
|
m = websockets.Masker(b"abcd")
|
||||||
|
data = b"".join([m(t) for t in input])
|
||||||
|
assert data == codecs.decode(expected, 'hex')
|
||||||
|
|
||||||
|
data = websockets.Masker(b"abcd")(data)
|
||||||
|
assert data == b"".join(input)
|
105
test/netlib/websockets/test_utils.py
Normal file
105
test/netlib/websockets/test_utils.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from netlib import http
|
||||||
|
from netlib import websockets
|
||||||
|
|
||||||
|
|
||||||
|
class TestUtils(object):
|
||||||
|
|
||||||
|
def test_client_handshake_headers(self):
|
||||||
|
h = websockets.client_handshake_headers(version='42')
|
||||||
|
assert h['sec-websocket-version'] == '42'
|
||||||
|
|
||||||
|
h = websockets.client_handshake_headers(key='some-key')
|
||||||
|
assert h['sec-websocket-key'] == 'some-key'
|
||||||
|
|
||||||
|
h = websockets.client_handshake_headers(protocol='foobar')
|
||||||
|
assert h['sec-websocket-protocol'] == 'foobar'
|
||||||
|
|
||||||
|
h = websockets.client_handshake_headers(extensions='foo; bar')
|
||||||
|
assert h['sec-websocket-extensions'] == 'foo; bar'
|
||||||
|
|
||||||
|
def test_server_handshake_headers(self):
|
||||||
|
h = websockets.server_handshake_headers('some-key')
|
||||||
|
assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
|
||||||
|
assert 'sec-websocket-protocol' not in h
|
||||||
|
assert 'sec-websocket-extensions' not in h
|
||||||
|
|
||||||
|
h = websockets.server_handshake_headers('some-key', 'foobar', 'foo; bar')
|
||||||
|
assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
|
||||||
|
assert h['sec-websocket-protocol'] == 'foobar'
|
||||||
|
assert h['sec-websocket-extensions'] == 'foo; bar'
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], True),
|
||||||
|
([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'foobar')], True),
|
||||||
|
([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-KeY', b'foobar')], True),
|
||||||
|
([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-AccePt', b'foobar')], True),
|
||||||
|
([(b'connection', b'foo'), (b'upgrade', b'bar'), (b'sec-websocket-key', b'foobar')], False),
|
||||||
|
([(b'connection', b'upgrade'), (b'upgrade', b'websocket')], False),
|
||||||
|
([(b'connection', b'upgrade'), (b'sec-websocket-key', b'foobar')], False),
|
||||||
|
([(b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], False),
|
||||||
|
([], False),
|
||||||
|
])
|
||||||
|
def test_check_handshake(self, input, expected):
|
||||||
|
h = http.Headers(input)
|
||||||
|
assert websockets.check_handshake(h) == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
([(b'sec-websocket-version', b'13')], True),
|
||||||
|
([(b'Sec-WebSockeT-VerSion', b'13')], True),
|
||||||
|
([(b'sec-websocket-version', b'9')], False),
|
||||||
|
([(b'sec-websocket-version', b'42')], False),
|
||||||
|
([(b'sec-websocket-version', b'')], False),
|
||||||
|
([], False),
|
||||||
|
])
|
||||||
|
def test_check_client_version(self, input, expected):
|
||||||
|
h = http.Headers(input)
|
||||||
|
assert websockets.check_client_version(h) == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
('foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
|
||||||
|
(b'foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
|
||||||
|
])
|
||||||
|
def test_create_server_nonce(self, input, expected):
|
||||||
|
assert websockets.create_server_nonce(input) == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
([(b'sec-websocket-extensions', b'foo; bar')], 'foo; bar'),
|
||||||
|
([(b'Sec-WebSockeT-ExteNsionS', b'foo; bar')], 'foo; bar'),
|
||||||
|
([(b'sec-websocket-extensions', b'')], ''),
|
||||||
|
([], None),
|
||||||
|
])
|
||||||
|
def test_get_extensions(self, input, expected):
|
||||||
|
h = http.Headers(input)
|
||||||
|
assert websockets.get_extensions(h) == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
([(b'sec-websocket-protocol', b'foobar')], 'foobar'),
|
||||||
|
([(b'Sec-WebSockeT-ProTocoL', b'foobar')], 'foobar'),
|
||||||
|
([(b'sec-websocket-protocol', b'')], ''),
|
||||||
|
([], None),
|
||||||
|
])
|
||||||
|
def test_get_protocol(self, input, expected):
|
||||||
|
h = http.Headers(input)
|
||||||
|
assert websockets.get_protocol(h) == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
([(b'sec-websocket-key', b'foobar')], 'foobar'),
|
||||||
|
([(b'Sec-WebSockeT-KeY', b'foobar')], 'foobar'),
|
||||||
|
([(b'sec-websocket-key', b'')], ''),
|
||||||
|
([], None),
|
||||||
|
])
|
||||||
|
def test_get_client_key(self, input, expected):
|
||||||
|
h = http.Headers(input)
|
||||||
|
assert websockets.get_client_key(h) == expected
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("input,expected", [
|
||||||
|
([(b'sec-websocket-accept', b'foobar')], 'foobar'),
|
||||||
|
([(b'Sec-WebSockeT-AccepT', b'foobar')], 'foobar'),
|
||||||
|
([(b'sec-websocket-accept', b'')], ''),
|
||||||
|
([], None),
|
||||||
|
])
|
||||||
|
def test_get_server_accept(self, input, expected):
|
||||||
|
h = http.Headers(input)
|
||||||
|
assert websockets.get_server_accept(h) == expected
|
@ -1,269 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
from netlib.http.http1 import read_response, read_request
|
|
||||||
|
|
||||||
from netlib import tcp
|
|
||||||
from netlib import tutils
|
|
||||||
from netlib import websockets
|
|
||||||
from netlib.http import status_codes
|
|
||||||
from netlib.tutils import treq
|
|
||||||
from netlib import exceptions
|
|
||||||
|
|
||||||
from .. import tservers
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketsEchoHandler(tcp.BaseHandler):
|
|
||||||
|
|
||||||
def __init__(self, connection, address, server):
|
|
||||||
super(WebSocketsEchoHandler, self).__init__(
|
|
||||||
connection, address, server
|
|
||||||
)
|
|
||||||
self.protocol = websockets.WebsocketsProtocol()
|
|
||||||
self.handshake_done = False
|
|
||||||
|
|
||||||
def handle(self):
|
|
||||||
while True:
|
|
||||||
if not self.handshake_done:
|
|
||||||
self.handshake()
|
|
||||||
else:
|
|
||||||
self.read_next_message()
|
|
||||||
|
|
||||||
def read_next_message(self):
|
|
||||||
frame = websockets.Frame.from_file(self.rfile)
|
|
||||||
self.on_message(frame.payload)
|
|
||||||
|
|
||||||
def send_message(self, message):
|
|
||||||
frame = websockets.Frame.default(message, from_client=False)
|
|
||||||
frame.to_file(self.wfile)
|
|
||||||
|
|
||||||
def handshake(self):
|
|
||||||
|
|
||||||
req = read_request(self.rfile)
|
|
||||||
key = self.protocol.check_client_handshake(req.headers)
|
|
||||||
|
|
||||||
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
|
|
||||||
self.wfile.write(preamble.encode() + b"\r\n")
|
|
||||||
headers = self.protocol.server_handshake_headers(key)
|
|
||||||
self.wfile.write(str(headers) + "\r\n")
|
|
||||||
self.wfile.flush()
|
|
||||||
self.handshake_done = True
|
|
||||||
|
|
||||||
def on_message(self, message):
|
|
||||||
if message is not None:
|
|
||||||
self.send_message(message)
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketsClient(tcp.TCPClient):
|
|
||||||
|
|
||||||
def __init__(self, address, source_address=None):
|
|
||||||
super(WebSocketsClient, self).__init__(address, source_address)
|
|
||||||
self.protocol = websockets.WebsocketsProtocol()
|
|
||||||
self.client_nonce = None
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
super(WebSocketsClient, self).connect()
|
|
||||||
|
|
||||||
preamble = b'GET / HTTP/1.1'
|
|
||||||
self.wfile.write(preamble + b"\r\n")
|
|
||||||
headers = self.protocol.client_handshake_headers()
|
|
||||||
self.client_nonce = headers["sec-websocket-key"].encode("ascii")
|
|
||||||
self.wfile.write(bytes(headers) + b"\r\n")
|
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
resp = read_response(self.rfile, treq(method=b"GET"))
|
|
||||||
server_nonce = self.protocol.check_server_handshake(resp.headers)
|
|
||||||
|
|
||||||
if not server_nonce == self.protocol.create_server_nonce(self.client_nonce):
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def read_next_message(self):
|
|
||||||
return websockets.Frame.from_file(self.rfile).payload
|
|
||||||
|
|
||||||
def send_message(self, message):
|
|
||||||
frame = websockets.Frame.default(message, from_client=True)
|
|
||||||
frame.to_file(self.wfile)
|
|
||||||
|
|
||||||
|
|
||||||
class TestWebSockets(tservers.ServerTestBase):
|
|
||||||
handler = WebSocketsEchoHandler
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.protocol = websockets.WebsocketsProtocol()
|
|
||||||
|
|
||||||
def random_bytes(self, n=100):
|
|
||||||
return os.urandom(n)
|
|
||||||
|
|
||||||
def echo(self, msg):
|
|
||||||
client = WebSocketsClient(("127.0.0.1", self.port))
|
|
||||||
client.connect()
|
|
||||||
client.send_message(msg)
|
|
||||||
response = client.read_next_message()
|
|
||||||
assert response == msg
|
|
||||||
|
|
||||||
def test_simple_echo(self):
|
|
||||||
self.echo(b"hello I'm the client")
|
|
||||||
|
|
||||||
def test_frame_sizes(self):
|
|
||||||
# length can fit in the the 7 bit payload length
|
|
||||||
small_msg = self.random_bytes(100)
|
|
||||||
# 50kb, sligthly larger than can fit in a 7 bit int
|
|
||||||
medium_msg = self.random_bytes(50000)
|
|
||||||
# 150kb, slightly larger than can fit in a 16 bit int
|
|
||||||
large_msg = self.random_bytes(150000)
|
|
||||||
|
|
||||||
self.echo(small_msg)
|
|
||||||
self.echo(medium_msg)
|
|
||||||
self.echo(large_msg)
|
|
||||||
|
|
||||||
def test_default_builder(self):
|
|
||||||
"""
|
|
||||||
default builder should always generate valid frames
|
|
||||||
"""
|
|
||||||
msg = self.random_bytes()
|
|
||||||
assert websockets.Frame.default(msg, from_client=True)
|
|
||||||
assert websockets.Frame.default(msg, from_client=False)
|
|
||||||
|
|
||||||
def test_serialization_bijection(self):
|
|
||||||
"""
|
|
||||||
Ensure that various frame types can be serialized/deserialized back
|
|
||||||
and forth between to_bytes() and from_bytes()
|
|
||||||
"""
|
|
||||||
for is_client in [True, False]:
|
|
||||||
for num_bytes in [100, 50000, 150000]:
|
|
||||||
frame = websockets.Frame.default(
|
|
||||||
self.random_bytes(num_bytes), is_client
|
|
||||||
)
|
|
||||||
frame2 = websockets.Frame.from_bytes(
|
|
||||||
frame.to_bytes()
|
|
||||||
)
|
|
||||||
assert frame == frame2
|
|
||||||
|
|
||||||
bytes = b'\x81\x03cba'
|
|
||||||
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
|
|
||||||
|
|
||||||
def test_check_server_handshake(self):
|
|
||||||
headers = self.protocol.server_handshake_headers("key")
|
|
||||||
assert self.protocol.check_server_handshake(headers)
|
|
||||||
headers["Upgrade"] = "not_websocket"
|
|
||||||
assert not self.protocol.check_server_handshake(headers)
|
|
||||||
|
|
||||||
def test_check_client_handshake(self):
|
|
||||||
headers = self.protocol.client_handshake_headers("key")
|
|
||||||
assert self.protocol.check_client_handshake(headers) == "key"
|
|
||||||
headers["Upgrade"] = "not_websocket"
|
|
||||||
assert not self.protocol.check_client_handshake(headers)
|
|
||||||
|
|
||||||
|
|
||||||
class BadHandshakeHandler(WebSocketsEchoHandler):
|
|
||||||
|
|
||||||
def handshake(self):
|
|
||||||
|
|
||||||
client_hs = read_request(self.rfile)
|
|
||||||
self.protocol.check_client_handshake(client_hs.headers)
|
|
||||||
|
|
||||||
preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
|
|
||||||
self.wfile.write(preamble.encode())
|
|
||||||
headers = self.protocol.server_handshake_headers(b"malformed key")
|
|
||||||
self.wfile.write(bytes(headers) + b"\r\n")
|
|
||||||
self.wfile.flush()
|
|
||||||
self.handshake_done = True
|
|
||||||
|
|
||||||
|
|
||||||
class TestBadHandshake(tservers.ServerTestBase):
|
|
||||||
|
|
||||||
"""
|
|
||||||
Ensure that the client disconnects if the server handshake is malformed
|
|
||||||
"""
|
|
||||||
handler = BadHandshakeHandler
|
|
||||||
|
|
||||||
def test(self):
|
|
||||||
with tutils.raises(exceptions.TcpDisconnect):
|
|
||||||
client = WebSocketsClient(("127.0.0.1", self.port))
|
|
||||||
client.connect()
|
|
||||||
client.send_message(b"hello")
|
|
||||||
|
|
||||||
|
|
||||||
class TestFrameHeader:
|
|
||||||
|
|
||||||
def test_roundtrip(self):
|
|
||||||
def round(*args, **kwargs):
|
|
||||||
f = websockets.FrameHeader(*args, **kwargs)
|
|
||||||
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f)))
|
|
||||||
assert f == f2
|
|
||||||
round()
|
|
||||||
round(fin=1)
|
|
||||||
round(rsv1=1)
|
|
||||||
round(rsv2=1)
|
|
||||||
round(rsv3=1)
|
|
||||||
round(payload_length=1)
|
|
||||||
round(payload_length=100)
|
|
||||||
round(payload_length=1000)
|
|
||||||
round(payload_length=10000)
|
|
||||||
round(opcode=websockets.OPCODE.PING)
|
|
||||||
round(masking_key=b"test")
|
|
||||||
|
|
||||||
def test_human_readable(self):
|
|
||||||
f = websockets.FrameHeader(
|
|
||||||
masking_key=b"test",
|
|
||||||
fin=True,
|
|
||||||
payload_length=10
|
|
||||||
)
|
|
||||||
assert repr(f)
|
|
||||||
f = websockets.FrameHeader()
|
|
||||||
assert repr(f)
|
|
||||||
|
|
||||||
def test_funky(self):
|
|
||||||
f = websockets.FrameHeader(masking_key=b"test", mask=False)
|
|
||||||
raw = bytes(f)
|
|
||||||
f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
|
|
||||||
assert not f2.mask
|
|
||||||
|
|
||||||
def test_violations(self):
|
|
||||||
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
|
|
||||||
tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
|
|
||||||
|
|
||||||
def test_automask(self):
|
|
||||||
f = websockets.FrameHeader(mask=True)
|
|
||||||
assert f.masking_key
|
|
||||||
|
|
||||||
f = websockets.FrameHeader(masking_key=b"foob")
|
|
||||||
assert f.mask
|
|
||||||
|
|
||||||
f = websockets.FrameHeader(masking_key=b"foob", mask=0)
|
|
||||||
assert not f.mask
|
|
||||||
assert f.masking_key
|
|
||||||
|
|
||||||
|
|
||||||
class TestFrame:
|
|
||||||
|
|
||||||
def test_roundtrip(self):
|
|
||||||
def round(*args, **kwargs):
|
|
||||||
f = websockets.Frame(*args, **kwargs)
|
|
||||||
raw = bytes(f)
|
|
||||||
f2 = websockets.Frame.from_file(tutils.treader(raw))
|
|
||||||
assert f == f2
|
|
||||||
round(b"test")
|
|
||||||
round(b"test", fin=1)
|
|
||||||
round(b"test", rsv1=1)
|
|
||||||
round(b"test", opcode=websockets.OPCODE.PING)
|
|
||||||
round(b"test", masking_key=b"test")
|
|
||||||
|
|
||||||
def test_human_readable(self):
|
|
||||||
f = websockets.Frame()
|
|
||||||
assert repr(f)
|
|
||||||
|
|
||||||
|
|
||||||
def test_masker():
|
|
||||||
tests = [
|
|
||||||
[b"a"],
|
|
||||||
[b"four"],
|
|
||||||
[b"fourf"],
|
|
||||||
[b"fourfive"],
|
|
||||||
[b"a", b"aasdfasdfa", b"asdf"],
|
|
||||||
[b"a" * 50, b"aasdfasdfa", b"asdf"],
|
|
||||||
]
|
|
||||||
for i in tests:
|
|
||||||
m = websockets.Masker(b"abcd")
|
|
||||||
data = b"".join([m(t) for t in i])
|
|
||||||
data2 = websockets.Masker(b"abcd")(data)
|
|
||||||
assert data2 == b"".join(i)
|
|
Loading…
Reference in New Issue
Block a user