Merge pull request #1488 from mitmproxy/websockets

add WebSockets support
This commit is contained in:
Thomas Kriechbaumer 2016-09-01 10:39:57 +02:00 committed by GitHub
commit 55d938b880
24 changed files with 965 additions and 487 deletions

View File

@ -126,6 +126,18 @@ HTTP Events
:param HTTPFlow flow: The flow containing the error. :param HTTPFlow flow: The flow containing the error.
It is guaranteed to have non-None ``error`` attribute. It is guaranteed to have non-None ``error`` attribute.
WebSockets Events
^^^^^^^^^^^^^^^^^
.. py:function:: websockets_handshake(context, flow)
Called when a client wants to establish a WebSockets connection.
The WebSockets-specific headers can be manipulated to manipulate the handshake.
The ``flow`` object is guaranteed to have a non-None ``request`` attribute.
:param HTTPFlow flow: The flow containing the request which has been received.
The object is guaranteed to have a non-None ``request`` attribute.
TCP Events TCP Events
^^^^^^^^^^ ^^^^^^^^^^

View File

@ -28,6 +28,8 @@ Events = frozenset([
"response", "response",
"responseheaders", "responseheaders",
"websockets_handshake",
"next_layer", "next_layer",
"error", "error",

View File

@ -334,6 +334,10 @@ class FlowMaster(controller.Master):
self.client_playback.clear(f) self.client_playback.clear(f)
return f return f
@controller.handler
def websockets_handshake(self, f):
return f
def handle_intercept(self, f): def handle_intercept(self, f):
self.state.update_flow(f) self.state.update_flow(f)

View File

@ -29,8 +29,10 @@ from __future__ import absolute_import, print_function, division
from .base import Layer, ServerConnectionMixin from .base import Layer, ServerConnectionMixin
from .http import UpstreamConnectLayer from .http import UpstreamConnectLayer
from .http import HttpLayer
from .http1 import Http1Layer from .http1 import Http1Layer
from .http2 import Http2Layer from .http2 import Http2Layer
from .websockets import WebSocketsLayer
from .rawtcp import RawTCPLayer from .rawtcp import RawTCPLayer
from .tls import TlsClientHello from .tls import TlsClientHello
from .tls import TlsLayer from .tls import TlsLayer
@ -40,7 +42,9 @@ __all__ = [
"Layer", "ServerConnectionMixin", "Layer", "ServerConnectionMixin",
"TlsLayer", "is_tls_record_magic", "TlsClientHello", "TlsLayer", "is_tls_record_magic", "TlsClientHello",
"UpstreamConnectLayer", "UpstreamConnectLayer",
"HttpLayer",
"Http1Layer", "Http1Layer",
"Http2Layer", "Http2Layer",
"WebSocketsLayer",
"RawTCPLayer", "RawTCPLayer",
] ]

View File

@ -7,12 +7,14 @@ import traceback
import h2.exceptions import h2.exceptions
import six import six
import netlib.exceptions
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import models from mitmproxy import models
from mitmproxy.protocol import base from mitmproxy.protocol import base
import netlib.exceptions
from netlib import http from netlib import http
from netlib import tcp from netlib import tcp
from netlib import websockets
class _HttpTransmissionLayer(base.Layer): class _HttpTransmissionLayer(base.Layer):
@ -189,6 +191,11 @@ class HttpLayer(base.Layer):
self.process_request_hook(flow) self.process_request_hook(flow)
try: try:
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
# we only support RFC6455 with WebSockets version 13
# allow inline scripts to manupulate the client handshake
self.channel.ask("websockets_handshake", flow)
if not flow.response: if not flow.response:
self.establish_server_connection( self.establish_server_connection(
flow.request.host, flow.request.host,
@ -212,7 +219,7 @@ class HttpLayer(base.Layer):
# It may be useful to pass additional args (such as the upgrade header) # It may be useful to pass additional args (such as the upgrade header)
# to next_layer in the future # to next_layer in the future
if flow.response.status_code == 101: if flow.response.status_code == 101:
layer = self.ctx.next_layer(self) layer = self.ctx.next_layer(self, flow)
layer() layer()
return return

View File

@ -0,0 +1,108 @@
from __future__ import absolute_import, print_function, division
import socket
import struct
from OpenSSL import SSL
from mitmproxy import exceptions
from mitmproxy.protocol import base
import netlib.exceptions
from netlib import tcp
from netlib import websockets
class WebSocketsLayer(base.Layer):
"""
WebSockets layer to intercept, modify, and forward WebSockets connections
Only version 13 is supported (as specified in RFC6455)
Only HTTP/1.1-initiated connections are supported.
The client starts by sending an Upgrade-request.
In order to determine the handshake and negotiate the correct protocol
and extensions, the Upgrade-request is forwarded to the server.
The response from the server is then parsed and negotiated settings are extracted.
Finally the handshake is completed by forwarding the server-response to the client.
After that, only WebSockets frames are exchanged.
PING/PONG frames pass through and must be answered by the other endpoint.
CLOSE frames are forwarded before this WebSocketsLayer terminates.
This layer is transparent to any negotiated extensions.
This layer is transparent to any negotiated subprotocols.
Only raw frames are forwarded to the other endpoint.
"""
def __init__(self, ctx, flow):
super(WebSocketsLayer, self).__init__(ctx)
self._flow = flow
self.client_key = websockets.get_client_key(self._flow.request.headers)
self.client_protocol = websockets.get_protocol(self._flow.request.headers)
self.client_extensions = websockets.get_extensions(self._flow.request.headers)
self.server_accept = websockets.get_server_accept(self._flow.response.headers)
self.server_protocol = websockets.get_protocol(self._flow.response.headers)
self.server_extensions = websockets.get_extensions(self._flow.response.headers)
def _handle_frame(self, frame, source_conn, other_conn, is_server):
self.log(
"WebSockets Frame received from {}".format("server" if is_server else "client"),
"debug",
[repr(frame)]
)
if frame.header.opcode & 0x8 == 0:
# forward the data frame to the other side
other_conn.send(bytes(frame))
self.log("WebSockets frame received by {}: {}".format(is_server, frame), "debug")
elif frame.header.opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
# just forward the ping/pong to the other side
other_conn.send(bytes(frame))
elif frame.header.opcode == websockets.OPCODE.CLOSE:
other_conn.send(bytes(frame))
code = '(status code missing)'
msg = None
reason = '(message missing)'
if len(frame.payload) >= 2:
code, = struct.unpack('!H', frame.payload[:2])
msg = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
if len(frame.payload) > 2:
reason = frame.payload[2:]
self.log("WebSockets connection closed: {} {}, {}".format(code, msg, reason), "info")
# close the connection
return False
else:
# unknown frame - just forward it
other_conn.send(bytes(frame))
# continue the connection
return True
def __call__(self):
client = self.client_conn.connection
server = self.server_conn.connection
conns = [client, server]
try:
while not self.channel.should_exit.is_set():
r = tcp.ssl_read_select(conns, 1)
for conn in r:
source_conn = self.client_conn if conn == client else self.server_conn
other_conn = self.server_conn if conn == client else self.client_conn
is_server = (conn == self.server_conn.connection)
frame = websockets.Frame.from_file(source_conn.rfile)
if not self._handle_frame(frame, source_conn, other_conn, is_server):
return
except (socket.error, netlib.exceptions.TcpException, SSL.Error) as e:
self.log("WebSockets connection closed unexpectedly by {}: {}".format(
"server" if is_server else "client", repr(e)), "info")
except Exception as e: # pragma: no cover
raise exceptions.ProtocolException("Error in WebSockets connection: {}".format(repr(e)))

View File

@ -4,6 +4,7 @@ import sys
import six import six
from netlib import websockets
import netlib.exceptions import netlib.exceptions
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import protocol from mitmproxy import protocol
@ -32,7 +33,7 @@ class RootContext(object):
self.channel = channel self.channel = channel
self.config = config self.config = config
def next_layer(self, top_layer): def next_layer(self, top_layer, flow=None):
""" """
This function determines the next layer in the protocol stack. This function determines the next layer in the protocol stack.
@ -42,10 +43,22 @@ class RootContext(object):
Returns: Returns:
The next layer The next layer
""" """
layer = self._next_layer(top_layer) layer = self._next_layer(top_layer, flow)
return self.channel.ask("next_layer", layer) return self.channel.ask("next_layer", layer)
def _next_layer(self, top_layer): def _next_layer(self, top_layer, flow):
if flow is not None:
# We already have a flow, try to derive the next information from it
# Check for WebSockets handshake
is_websockets = (
flow and
websockets.check_handshake(flow.request.headers) and
websockets.check_handshake(flow.response.headers)
)
if isinstance(top_layer, protocol.HttpLayer) and is_websockets:
return protocol.WebSocketsLayer(top_layer, flow)
try: try:
d = top_layer.client_conn.rfile.peek(3) d = top_layer.client_conn.rfile.peek(3)
except netlib.exceptions.TcpException as e: except netlib.exceptions.TcpException as e:

View File

@ -1,11 +1,37 @@
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from .frame import FrameHeader, Frame, OPCODE
from .protocol import Masker, WebsocketsProtocol from .frame import FrameHeader
from .frame import Frame
from .frame import OPCODE
from .frame import CLOSE_REASON
from .masker import Masker
from .utils import MAGIC
from .utils import VERSION
from .utils import client_handshake_headers
from .utils import server_handshake_headers
from .utils import check_handshake
from .utils import check_client_version
from .utils import create_server_nonce
from .utils import get_extensions
from .utils import get_protocol
from .utils import get_client_key
from .utils import get_server_accept
__all__ = [ __all__ = [
"FrameHeader", "FrameHeader",
"Frame", "Frame",
"Masker",
"WebsocketsProtocol",
"OPCODE", "OPCODE",
"CLOSE_REASON",
"Masker",
"MAGIC",
"VERSION",
"client_handshake_headers",
"server_handshake_headers",
"check_handshake",
"check_client_version",
"create_server_nonce",
"get_extensions",
"get_protocol",
"get_client_key",
"get_server_accept",
] ]

View File

@ -2,7 +2,6 @@ from __future__ import absolute_import
import os import os
import struct import struct
import io import io
import warnings
import six import six
@ -10,7 +9,7 @@ from netlib import tcp
from netlib import strutils from netlib import strutils
from netlib import utils from netlib import utils
from netlib import human from netlib import human
from netlib.websockets import protocol from .masker import Masker
MAX_16_BIT_INT = (1 << 16) MAX_16_BIT_INT = (1 << 16)
@ -18,6 +17,7 @@ MAX_64_BIT_INT = (1 << 64)
DEFAULT = object() DEFAULT = object()
# RFC 6455, Section 5.2 - Base Framing Protocol
OPCODE = utils.BiDi( OPCODE = utils.BiDi(
CONTINUE=0x00, CONTINUE=0x00,
TEXT=0x01, TEXT=0x01,
@ -27,6 +27,23 @@ OPCODE = utils.BiDi(
PONG=0x0a PONG=0x0a
) )
# RFC 6455, Section 7.4.1 - Defined Status Codes
CLOSE_REASON = utils.BiDi(
NORMAL_CLOSURE=1000,
GOING_AWAY=1001,
PROTOCOL_ERROR=1002,
UNSUPPORTED_DATA=1003,
RESERVED=1004,
RESERVED_NO_STATUS=1005,
RESERVED_ABNORMAL_CLOSURE=1006,
INVALID_PAYLOAD_DATA=1007,
POLICY_VIOLATION=1008,
MESSAGE_TOO_BIG=1009,
MANDATORY_EXTENSION=1010,
INTERNAL_ERROR=1011,
RESERVED_TLS_HANDHSAKE_FAILED=1015,
)
class FrameHeader(object): class FrameHeader(object):
@ -103,10 +120,6 @@ class FrameHeader(object):
vals.append(" %s" % human.pretty_size(self.payload_length)) vals.append(" %s" % human.pretty_size(self.payload_length))
return "".join(vals) return "".join(vals)
def human_readable(self):
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
return repr(self)
def __bytes__(self): def __bytes__(self):
first_byte = utils.setbit(0, 7, self.fin) first_byte = utils.setbit(0, 7, self.fin)
first_byte = utils.setbit(first_byte, 6, self.rsv1) first_byte = utils.setbit(first_byte, 6, self.rsv1)
@ -128,6 +141,9 @@ class FrameHeader(object):
# '!Q' = pack as 64 bit unsigned long long # '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length # add 8 bytes extended payload length
b += struct.pack('!Q', self.payload_length) b += struct.pack('!Q', self.payload_length)
else:
raise ValueError("Payload length exceeds 64bit integer")
if self.masking_key: if self.masking_key:
b += self.masking_key b += self.masking_key
return b return b
@ -135,10 +151,6 @@ class FrameHeader(object):
if six.PY2: if six.PY2:
__str__ = __bytes__ __str__ = __bytes__
def to_bytes(self):
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
return bytes(self)
@classmethod @classmethod
def from_file(cls, fp): def from_file(cls, fp):
""" """
@ -151,19 +163,17 @@ class FrameHeader(object):
rsv1 = utils.getbit(first_byte, 6) rsv1 = utils.getbit(first_byte, 6)
rsv2 = utils.getbit(first_byte, 5) rsv2 = utils.getbit(first_byte, 5)
rsv3 = utils.getbit(first_byte, 4) rsv3 = utils.getbit(first_byte, 4)
# grab right-most 4 bits opcode = first_byte & 0xF
opcode = first_byte & 15
mask_bit = utils.getbit(second_byte, 7) mask_bit = utils.getbit(second_byte, 7)
# grab the next 7 bits length_code = second_byte & 0x7F
length_code = second_byte & 127
# payload_lengthy > 125 indicates you need to read more bytes # payload_length > 125 indicates you need to read more bytes
# to get the actual payload length # to get the actual payload length
if length_code <= 125: if length_code <= 125:
payload_length = length_code payload_length = length_code
elif length_code == 126: elif length_code == 126:
payload_length, = struct.unpack("!H", fp.safe_read(2)) payload_length, = struct.unpack("!H", fp.safe_read(2))
elif length_code == 127: else: # length_code == 127:
payload_length, = struct.unpack("!Q", fp.safe_read(8)) payload_length, = struct.unpack("!Q", fp.safe_read(8))
# masking key only present if mask bit set # masking key only present if mask bit set
@ -191,31 +201,30 @@ class FrameHeader(object):
class Frame(object): class Frame(object):
""" """
Represents one websockets frame. Represents a single WebSockets frame.
Constructor takes human readable forms of the frame components Constructor takes human readable forms of the frame components.
from_bytes() is also avaliable. from_bytes() reads from a file-like object to create a new Frame.
WebSockets Frame as defined in RFC6455 WebSockets Frame as defined in RFC6455
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+ +-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len | Extended payload length | |F|R|R|R| opcode|M| Payload len | Extended payload length |
|I|S|S|S| (4) |A| (7) | (16/64) | |I|S|S|S| (4) |A| (7) | (16/64) |
|N|V|V|V| |S| | (if payload len==126/127) | |N|V|V|V| |S| | (if payload len==126/127) |
| |1|2|3| |K| | | | |1|2|3| |K| | |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
| Extended payload length continued, if payload len == 127 | | Extended payload length continued, if payload len == 127 |
+ - - - - - - - - - - - - - - - +-------------------------------+ + - - - - - - - - - - - - - - - +-------------------------------+
| |Masking-key, if MASK set to 1 | | |Masking-key, if MASK set to 1 |
+-------------------------------+-------------------------------+ +-------------------------------+-------------------------------+
| Masking-key (continued) | Payload Data | | Masking-key (continued) | Payload Data |
+-------------------------------- - - - - - - - - - - - - - - - + +-------------------------------- - - - - - - - - - - - - - - - +
: Payload Data continued ... : : Payload Data continued ... :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
| Payload Data continued ... | | Payload Data continued ... |
+---------------------------------------------------------------+ +---------------------------------------------------------------+
""" """
def __init__(self, payload=b"", **kwargs): def __init__(self, payload=b"", **kwargs):
@ -223,27 +232,6 @@ class Frame(object):
kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
self.header = FrameHeader(**kwargs) self.header = FrameHeader(**kwargs)
@classmethod
def default(cls, message, from_client=False):
"""
Construct a basic websocket frame from some default values.
Creates a non-fragmented text frame.
"""
if from_client:
mask_bit = 1
masking_key = os.urandom(4)
else:
mask_bit = 0
masking_key = None
return cls(
message,
fin=1, # final frame
opcode=OPCODE.TEXT, # text
mask=mask_bit,
masking_key=masking_key,
)
@classmethod @classmethod
def from_bytes(cls, bytestring): def from_bytes(cls, bytestring):
""" """
@ -258,17 +246,13 @@ class Frame(object):
ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload) ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload)
return ret return ret
def human_readable(self):
warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
return repr(self)
def __bytes__(self): def __bytes__(self):
""" """
Serialize the frame to wire format. Returns a string. Serialize the frame to wire format. Returns a string.
""" """
b = bytes(self.header) b = bytes(self.header)
if self.header.masking_key: if self.header.masking_key:
b += protocol.Masker(self.header.masking_key)(self.payload) b += Masker(self.header.masking_key)(self.payload)
else: else:
b += self.payload b += self.payload
return b return b
@ -276,15 +260,6 @@ class Frame(object):
if six.PY2: if six.PY2:
__str__ = __bytes__ __str__ = __bytes__
def to_bytes(self):
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
return bytes(self)
def to_file(self, writer):
warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
writer.write(bytes(self))
writer.flush()
@classmethod @classmethod
def from_file(cls, fp): def from_file(cls, fp):
""" """
@ -297,20 +272,11 @@ class Frame(object):
payload = fp.safe_read(header.payload_length) payload = fp.safe_read(header.payload_length)
if header.mask == 1 and header.masking_key: if header.mask == 1 and header.masking_key:
payload = protocol.Masker(header.masking_key)(payload) payload = Masker(header.masking_key)(payload)
return cls( frame = cls(payload)
payload, frame.header = header
fin=header.fin, return frame
opcode=header.opcode,
mask=header.mask,
payload_length=header.payload_length,
masking_key=header.masking_key,
rsv1=header.rsv1,
rsv2=header.rsv2,
rsv3=header.rsv3,
length_code=header.length_code
)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Frame): if isinstance(other, Frame):

View File

@ -0,0 +1,33 @@
from __future__ import absolute_import
import six
class Masker(object):
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns.
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
def __init__(self, key):
self.key = key
self.offset = 0
def mask(self, offset, data):
result = bytearray(data)
for i in range(len(data)):
if six.PY2:
result[i] ^= ord(self.key[offset % 4])
else:
result[i] ^= self.key[offset % 4]
offset += 1
result = bytes(result)
return result
def __call__(self, data):
ret = self.mask(self.offset, data)
self.offset += len(ret)
return ret

View File

@ -1,112 +0,0 @@
"""
Colleciton of utility functions that implement small portions of the RFC6455
WebSockets Protocol Useful for building WebSocket clients and servers.
Emphassis is on readabilty, simplicity and modularity, not performance or
completeness
This is a work in progress and does not yet contain all the utilites need to
create fully complient client/servers #
Spec: https://tools.ietf.org/html/rfc6455
The magic sha that websocket servers must know to prove they understand
RFC6455
"""
from __future__ import absolute_import
import base64
import hashlib
import os
import six
from netlib import http, strutils
websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
class Masker(object):
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
def __init__(self, key):
self.key = key
self.offset = 0
def mask(self, offset, data):
result = bytearray(data)
if six.PY2:
for i in range(len(data)):
result[i] ^= ord(self.key[offset % 4])
offset += 1
result = str(result)
else:
for i in range(len(data)):
result[i] ^= self.key[offset % 4]
offset += 1
result = bytes(result)
return result
def __call__(self, data):
ret = self.mask(self.offset, data)
self.offset += len(ret)
return ret
class WebsocketsProtocol(object):
def __init__(self):
pass
@classmethod
def client_handshake_headers(self, key=None, version=VERSION):
"""
Create the headers for a valid HTTP upgrade request. If Key is not
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of http.Headers
"""
if not key:
key = base64.b64encode(os.urandom(16)).decode('ascii')
return http.Headers(
sec_websocket_key=key,
sec_websocket_version=version,
connection="Upgrade",
upgrade="websocket",
)
@classmethod
def server_handshake_headers(self, key):
"""
The server response is a valid HTTP 101 response.
"""
return http.Headers(
sec_websocket_accept=self.create_server_nonce(key),
connection="Upgrade",
upgrade="websocket"
)
@classmethod
def check_client_handshake(self, headers):
if headers.get("upgrade") != "websocket":
return
return headers.get("sec-websocket-key")
@classmethod
def check_server_handshake(self, headers):
if headers.get("upgrade") != "websocket":
return
return headers.get("sec-websocket-accept")
@classmethod
def create_server_nonce(self, client_nonce):
return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + websockets_magic).digest())

View File

@ -0,0 +1,90 @@
"""
Collection of WebSockets Protocol utility functions (RFC6455)
Spec: https://tools.ietf.org/html/rfc6455
"""
from __future__ import absolute_import
import base64
import hashlib
import os
from netlib import http, strutils
MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
"""
Create the headers for a valid HTTP upgrade request. If Key is not
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of http.Headers
"""
if version is None:
version = VERSION
if key is None:
key = base64.b64encode(os.urandom(16)).decode('ascii')
h = http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_version=version,
sec_websocket_key=key,
)
if protocol is not None:
h['sec-websocket-protocol'] = protocol
if extensions is not None:
h['sec-websocket-extensions'] = extensions
return h
def server_handshake_headers(client_key, protocol=None, extensions=None):
"""
The server response is a valid HTTP 101 response.
Returns an instance of http.Headers
"""
h = http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_accept=create_server_nonce(client_key),
)
if protocol is not None:
h['sec-websocket-protocol'] = protocol
if extensions is not None:
h['sec-websocket-extensions'] = extensions
return h
def check_handshake(headers):
return (
"upgrade" in headers.get("connection", "").lower() and
headers.get("upgrade", "").lower() == "websocket" and
(headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None)
)
def create_server_nonce(client_nonce):
return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest())
def check_client_version(headers):
return headers.get("sec-websocket-version", "") == VERSION
def get_extensions(headers):
return headers.get("sec-websocket-extensions", None)
def get_protocol(headers):
return headers.get("sec-websocket-protocol", None)
def get_client_key(headers):
return headers.get("sec-websocket-key", None)
def get_server_accept(headers):
return headers.get("sec-websocket-accept", None)

View File

@ -198,7 +198,7 @@ class Response(_HTTPMessage):
1, 1,
StatusCode(101) StatusCode(101)
) )
headers = netlib.websockets.WebsocketsProtocol.server_handshake_headers( headers = netlib.websockets.server_handshake_headers(
settings.websocket_key settings.websocket_key
) )
for i in headers.fields: for i in headers.fields:
@ -310,7 +310,7 @@ class Request(_HTTPMessage):
1, 1,
Method("get") Method("get")
) )
for i in netlib.websockets.WebsocketsProtocol.client_handshake_headers().fields: for i in netlib.websockets.client_handshake_headers().fields:
if not get_header(i[0], self.headers): if not get_header(i[0], self.headers):
tokens.append( tokens.append(
Header( Header(

View File

@ -139,7 +139,7 @@ class WebsocketFrameReader(basethread.BaseThread):
except exceptions.TcpDisconnect: except exceptions.TcpDisconnect:
return return
self.frames_queue.put(frm) self.frames_queue.put(frm)
log("<< %s" % frm.header.human_readable()) log("<< %s" % repr(frm.header))
if self.ws_read_limit is not None: if self.ws_read_limit is not None:
self.ws_read_limit -= 1 self.ws_read_limit -= 1
starttime = time.time() starttime = time.time()

View File

@ -173,12 +173,13 @@ class PathodHandler(tcp.BaseHandler):
retlog["cipher"] = self.get_current_cipher() retlog["cipher"] = self.get_current_cipher()
m = utils.MemBool() m = utils.MemBool()
websocket_key = websockets.WebsocketsProtocol.check_client_handshake(headers)
self.settings.websocket_key = websocket_key valid_websockets_handshake = websockets.check_handshake(headers)
self.settings.websocket_key = websockets.get_client_key(headers)
# If this is a websocket initiation, we respond with a proper # If this is a websocket initiation, we respond with a proper
# server response, unless over-ridden. # server response, unless over-ridden.
if websocket_key: if valid_websockets_handshake:
anchor_gen = language.parse_pathod("ws") anchor_gen = language.parse_pathod("ws")
else: else:
anchor_gen = None anchor_gen = None
@ -225,7 +226,7 @@ class PathodHandler(tcp.BaseHandler):
spec, spec,
lg lg
) )
if nexthandler and websocket_key: if nexthandler and valid_websockets_handshake:
self.protocol = protocols.websockets.WebsocketsProtocol(self) self.protocol = protocols.websockets.WebsocketsProtocol(self)
return self.protocol.handle_websocket, retlog return self.protocol.handle_websocket, retlog
else: else:

View File

@ -20,7 +20,7 @@ class WebsocketsProtocol:
lg("Error reading websocket frame: %s" % e) lg("Error reading websocket frame: %s" % e)
return None, None return None, None
ended = time.time() ended = time.time()
lg(frm.human_readable()) lg(repr(frm))
retlog = dict( retlog = dict(
type="inbound", type="inbound",
protocol="websockets", protocol="websockets",

View File

View File

@ -1,7 +1,9 @@
from __future__ import (absolute_import, print_function, division)
from netlib.http import http1 from netlib.http import http1
from netlib.tcp import TCPClient from netlib.tcp import TCPClient
from netlib.tutils import treq from netlib.tutils import treq
from . import tutils, tservers from .. import tutils, tservers
class TestHTTPFlow(object): class TestHTTPFlow(object):

View File

@ -13,11 +13,11 @@ from mitmproxy import options
from mitmproxy.proxy.config import ProxyConfig from mitmproxy.proxy.config import ProxyConfig
import netlib import netlib
from ..netlib import tservers as netlib_tservers from ...netlib import tservers as netlib_tservers
from netlib.exceptions import HttpException from netlib.exceptions import HttpException
from netlib.http.http2 import framereader from netlib.http.http2 import framereader
from . import tservers from .. import tservers
import logging import logging
logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING) logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING)

View File

@ -0,0 +1,299 @@
from __future__ import absolute_import, print_function, division
import pytest
import os
import tempfile
import traceback
from mitmproxy import options
from mitmproxy.proxy.config import ProxyConfig
import netlib
from netlib import http
from ...netlib import tservers as netlib_tservers
from .. import tservers
from netlib import websockets
class _WebSocketsServerBase(netlib_tservers.ServerTestBase):
class handler(netlib.tcp.BaseHandler):
def handle(self):
try:
request = http.http1.read_request(self.rfile)
assert websockets.check_handshake(request.headers)
response = http.Response(
"HTTP/1.1",
101,
reason=http.status_codes.RESPONSES.get(101),
headers=http.Headers(
connection='upgrade',
upgrade='websocket',
sec_websocket_accept=b'',
),
content=b'',
)
self.wfile.write(http.http1.assemble_response(response))
self.wfile.flush()
self.server.handle_websockets(self.rfile, self.wfile)
except:
traceback.print_exc()
class _WebSocketsTestBase(object):
@classmethod
def setup_class(cls):
opts = cls.get_options()
cls.config = ProxyConfig(opts)
tmaster = tservers.TestMaster(opts, cls.config)
tmaster.start_app(options.APP_HOST, options.APP_PORT)
cls.proxy = tservers.ProxyThread(tmaster)
cls.proxy.start()
@classmethod
def teardown_class(cls):
cls.proxy.shutdown()
@classmethod
def get_options(cls):
opts = options.Options(
listen_port=0,
no_upstream_cert=False,
ssl_insecure=True
)
opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
return opts
@property
def master(self):
return self.proxy.tmaster
def setup(self):
self.master.clear_log()
self.master.state.clear()
self.server.server.handle_websockets = self.handle_websockets
def _setup_connection(self):
client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port))
client.connect()
request = http.Request(
"authority",
"CONNECT",
"",
"localhost",
self.server.server.address.port,
"",
"HTTP/1.1",
content=b'')
client.wfile.write(http.http1.assemble_request(request))
client.wfile.flush()
response = http.http1.read_response(client.rfile, request)
if self.ssl:
client.convert_to_ssl()
assert client.ssl_established
request = http.Request(
"relative",
"GET",
"http",
"localhost",
self.server.server.address.port,
"/ws",
"HTTP/1.1",
headers=http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_version="13",
sec_websocket_key="1234",
),
content=b'')
client.wfile.write(http.http1.assemble_request(request))
client.wfile.flush()
response = http.http1.read_response(client.rfile, request)
assert websockets.check_handshake(response.headers)
return client
class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase):
@classmethod
def setup_class(cls):
_WebSocketsTestBase.setup_class()
_WebSocketsServerBase.setup_class(ssl=cls.ssl)
@classmethod
def teardown_class(cls):
_WebSocketsTestBase.teardown_class()
_WebSocketsServerBase.teardown_class()
class TestSimple(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.flush()
def test_simple(self):
client = self._setup_connection()
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'server-foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
client.wfile.flush()
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'client-foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
client.wfile.flush()
class TestSimpleTLS(_WebSocketsTest):
ssl = True
@classmethod
def handle_websockets(cls, rfile, wfile):
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.flush()
def test_simple_tls(self):
client = self._setup_connection()
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'server-foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
client.wfile.flush()
frame = websockets.Frame.from_file(client.rfile)
assert frame.payload == b'client-foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
client.wfile.flush()
class TestPing(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
wfile.flush()
def test_ping(self):
client = self._setup_connection()
frame = websockets.Frame.from_file(client.rfile)
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar'
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
client.wfile.flush()
frame = websockets.Frame.from_file(client.rfile)
assert frame.header.opcode == websockets.OPCODE.TEXT
assert frame.payload == b'pong-received'
class TestPong(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar'
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
wfile.flush()
def test_pong(self):
client = self._setup_connection()
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
client.wfile.flush()
frame = websockets.Frame.from_file(client.rfile)
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
class TestClose(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.flush()
with pytest.raises(netlib.exceptions.TcpDisconnect):
websockets.Frame.from_file(rfile)
def test_close(self):
client = self._setup_connection()
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
client.wfile.flush()
with pytest.raises(netlib.exceptions.TcpDisconnect):
websockets.Frame.from_file(client.rfile)
def test_close_payload_1(self):
client = self._setup_connection()
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
client.wfile.flush()
with pytest.raises(netlib.exceptions.TcpDisconnect):
websockets.Frame.from_file(client.rfile)
def test_close_payload_2(self):
client = self._setup_connection()
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
client.wfile.flush()
with pytest.raises(netlib.exceptions.TcpDisconnect):
websockets.Frame.from_file(client.rfile)
class TestInvalidFrame(_WebSocketsTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar')))
wfile.flush()
def test_invalid_frame(self):
client = self._setup_connection()
# with pytest.raises(netlib.exceptions.TcpDisconnect):
frame = websockets.Frame.from_file(client.rfile)
assert frame.header.opcode == 15
assert frame.payload == b'foobar'

View File

@ -0,0 +1,164 @@
import os
import codecs
import pytest
from netlib import websockets
from netlib import tutils
class TestFrameHeader(object):
@pytest.mark.parametrize("input,expected", [
(0, '0100'),
(125, '017D'),
(126, '017E007E'),
(127, '017E007F'),
(142, '017E008E'),
(65534, '017EFFFE'),
(65535, '017EFFFF'),
(65536, '017F0000000000010000'),
(8589934591, '017F00000001FFFFFFFF'),
(2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
])
def test_serialization_length(self, input, expected):
h = websockets.FrameHeader(
opcode=websockets.OPCODE.TEXT,
payload_length=input,
)
assert bytes(h) == codecs.decode(expected, 'hex')
def test_serialization_too_large(self):
h = websockets.FrameHeader(
payload_length=2 ** 64 + 1,
)
with pytest.raises(ValueError):
bytes(h)
@pytest.mark.parametrize("input,expected", [
('0100', 0),
('017D', 125),
('017E007E', 126),
('017E007F', 127),
('017E008E', 142),
('017EFFFE', 65534),
('017EFFFF', 65535),
('017F0000000000010000', 65536),
('017F00000001FFFFFFFF', 8589934591),
('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
])
def test_deserialization_length(self, input, expected):
h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
assert h.payload_length == expected
@pytest.mark.parametrize("input,expected", [
('0100', (False, None)),
('018000000000', (True, '00000000')),
('018012345678', (True, '12345678')),
])
def test_deserialization_masking(self, input, expected):
h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
assert h.mask == expected[0]
if h.mask:
assert h.masking_key == codecs.decode(expected[1], 'hex')
def test_equality(self):
h = websockets.FrameHeader(mask=True, masking_key=b'1234')
h2 = websockets.FrameHeader(mask=True, masking_key=b'1234')
assert h == h2
h = websockets.FrameHeader(fin=True)
h2 = websockets.FrameHeader(fin=False)
assert h != h2
assert h != 'foobar'
def test_roundtrip(self):
def round(*args, **kwargs):
h = websockets.FrameHeader(*args, **kwargs)
h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h)))
assert h == h2
round()
round(fin=True)
round(rsv1=True)
round(rsv2=True)
round(rsv3=True)
round(payload_length=1)
round(payload_length=100)
round(payload_length=1000)
round(payload_length=10000)
round(opcode=websockets.OPCODE.PING)
round(masking_key=b"test")
def test_human_readable(self):
f = websockets.FrameHeader(
masking_key=b"test",
fin=True,
payload_length=10
)
assert repr(f)
f = websockets.FrameHeader()
assert repr(f)
def test_funky(self):
f = websockets.FrameHeader(masking_key=b"test", mask=False)
raw = bytes(f)
f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
assert not f2.mask
def test_violations(self):
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
def test_automask(self):
f = websockets.FrameHeader(mask=True)
assert f.masking_key
f = websockets.FrameHeader(masking_key=b"foob")
assert f.mask
f = websockets.FrameHeader(masking_key=b"foob", mask=0)
assert not f.mask
assert f.masking_key
class TestFrame(object):
def test_equality(self):
f = websockets.Frame(payload=b'1234')
f2 = websockets.Frame(payload=b'1234')
assert f == f2
assert f != b'1234'
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
raw = bytes(f)
f2 = websockets.Frame.from_file(tutils.treader(raw))
assert f == f2
round(b"test")
round(b"test", fin=1)
round(b"test", rsv1=1)
round(b"test", opcode=websockets.OPCODE.PING)
round(b"test", masking_key=b"test")
def test_human_readable(self):
f = websockets.Frame()
assert repr(f)
f = websockets.Frame(b"foobar")
assert "foobar" in repr(f)
@pytest.mark.parametrize("masked", [True, False])
@pytest.mark.parametrize("length", [100, 50000, 150000])
def test_serialization_bijection(self, masked, length):
frame = websockets.Frame(
os.urandom(length),
fin=True,
opcode=websockets.OPCODE.TEXT,
mask=int(masked),
masking_key=(os.urandom(4) if masked else None)
)
serialized = bytes(frame)
assert frame == websockets.Frame.from_bytes(serialized)

View File

@ -0,0 +1,23 @@
import codecs
import pytest
from netlib import websockets
class TestMasker(object):
@pytest.mark.parametrize("input,expected", [
([b"a"], '00'),
([b"four"], '070d1616'),
([b"fourf"], '070d161607'),
([b"fourfive"], '070d1616070b1501'),
([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'),
([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'), # noqa
])
def test_masker(self, input, expected):
m = websockets.Masker(b"abcd")
data = b"".join([m(t) for t in input])
assert data == codecs.decode(expected, 'hex')
data = websockets.Masker(b"abcd")(data)
assert data == b"".join(input)

View File

@ -0,0 +1,105 @@
import pytest
from netlib import http
from netlib import websockets
class TestUtils(object):
def test_client_handshake_headers(self):
h = websockets.client_handshake_headers(version='42')
assert h['sec-websocket-version'] == '42'
h = websockets.client_handshake_headers(key='some-key')
assert h['sec-websocket-key'] == 'some-key'
h = websockets.client_handshake_headers(protocol='foobar')
assert h['sec-websocket-protocol'] == 'foobar'
h = websockets.client_handshake_headers(extensions='foo; bar')
assert h['sec-websocket-extensions'] == 'foo; bar'
def test_server_handshake_headers(self):
h = websockets.server_handshake_headers('some-key')
assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
assert 'sec-websocket-protocol' not in h
assert 'sec-websocket-extensions' not in h
h = websockets.server_handshake_headers('some-key', 'foobar', 'foo; bar')
assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
assert h['sec-websocket-protocol'] == 'foobar'
assert h['sec-websocket-extensions'] == 'foo; bar'
@pytest.mark.parametrize("input,expected", [
([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], True),
([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'foobar')], True),
([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-KeY', b'foobar')], True),
([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-AccePt', b'foobar')], True),
([(b'connection', b'foo'), (b'upgrade', b'bar'), (b'sec-websocket-key', b'foobar')], False),
([(b'connection', b'upgrade'), (b'upgrade', b'websocket')], False),
([(b'connection', b'upgrade'), (b'sec-websocket-key', b'foobar')], False),
([(b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], False),
([], False),
])
def test_check_handshake(self, input, expected):
h = http.Headers(input)
assert websockets.check_handshake(h) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-version', b'13')], True),
([(b'Sec-WebSockeT-VerSion', b'13')], True),
([(b'sec-websocket-version', b'9')], False),
([(b'sec-websocket-version', b'42')], False),
([(b'sec-websocket-version', b'')], False),
([], False),
])
def test_check_client_version(self, input, expected):
h = http.Headers(input)
assert websockets.check_client_version(h) == expected
@pytest.mark.parametrize("input,expected", [
('foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
(b'foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
])
def test_create_server_nonce(self, input, expected):
assert websockets.create_server_nonce(input) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-extensions', b'foo; bar')], 'foo; bar'),
([(b'Sec-WebSockeT-ExteNsionS', b'foo; bar')], 'foo; bar'),
([(b'sec-websocket-extensions', b'')], ''),
([], None),
])
def test_get_extensions(self, input, expected):
h = http.Headers(input)
assert websockets.get_extensions(h) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-protocol', b'foobar')], 'foobar'),
([(b'Sec-WebSockeT-ProTocoL', b'foobar')], 'foobar'),
([(b'sec-websocket-protocol', b'')], ''),
([], None),
])
def test_get_protocol(self, input, expected):
h = http.Headers(input)
assert websockets.get_protocol(h) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-key', b'foobar')], 'foobar'),
([(b'Sec-WebSockeT-KeY', b'foobar')], 'foobar'),
([(b'sec-websocket-key', b'')], ''),
([], None),
])
def test_get_client_key(self, input, expected):
h = http.Headers(input)
assert websockets.get_client_key(h) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-accept', b'foobar')], 'foobar'),
([(b'Sec-WebSockeT-AccepT', b'foobar')], 'foobar'),
([(b'sec-websocket-accept', b'')], ''),
([], None),
])
def test_get_server_accept(self, input, expected):
h = http.Headers(input)
assert websockets.get_server_accept(h) == expected

View File

@ -1,269 +0,0 @@
import os
from netlib.http.http1 import read_response, read_request
from netlib import tcp
from netlib import tutils
from netlib import websockets
from netlib.http import status_codes
from netlib.tutils import treq
from netlib import exceptions
from .. import tservers
class WebSocketsEchoHandler(tcp.BaseHandler):
def __init__(self, connection, address, server):
super(WebSocketsEchoHandler, self).__init__(
connection, address, server
)
self.protocol = websockets.WebsocketsProtocol()
self.handshake_done = False
def handle(self):
while True:
if not self.handshake_done:
self.handshake()
else:
self.read_next_message()
def read_next_message(self):
frame = websockets.Frame.from_file(self.rfile)
self.on_message(frame.payload)
def send_message(self, message):
frame = websockets.Frame.default(message, from_client=False)
frame.to_file(self.wfile)
def handshake(self):
req = read_request(self.rfile)
key = self.protocol.check_client_handshake(req.headers)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble.encode() + b"\r\n")
headers = self.protocol.server_handshake_headers(key)
self.wfile.write(str(headers) + "\r\n")
self.wfile.flush()
self.handshake_done = True
def on_message(self, message):
if message is not None:
self.send_message(message)
class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
self.protocol = websockets.WebsocketsProtocol()
self.client_nonce = None
def connect(self):
super(WebSocketsClient, self).connect()
preamble = b'GET / HTTP/1.1'
self.wfile.write(preamble + b"\r\n")
headers = self.protocol.client_handshake_headers()
self.client_nonce = headers["sec-websocket-key"].encode("ascii")
self.wfile.write(bytes(headers) + b"\r\n")
self.wfile.flush()
resp = read_response(self.rfile, treq(method=b"GET"))
server_nonce = self.protocol.check_server_handshake(resp.headers)
if not server_nonce == self.protocol.create_server_nonce(self.client_nonce):
self.close()
def read_next_message(self):
return websockets.Frame.from_file(self.rfile).payload
def send_message(self, message):
frame = websockets.Frame.default(message, from_client=True)
frame.to_file(self.wfile)
class TestWebSockets(tservers.ServerTestBase):
handler = WebSocketsEchoHandler
def __init__(self):
self.protocol = websockets.WebsocketsProtocol()
def random_bytes(self, n=100):
return os.urandom(n)
def echo(self, msg):
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message(msg)
response = client.read_next_message()
assert response == msg
def test_simple_echo(self):
self.echo(b"hello I'm the client")
def test_frame_sizes(self):
# length can fit in the the 7 bit payload length
small_msg = self.random_bytes(100)
# 50kb, sligthly larger than can fit in a 7 bit int
medium_msg = self.random_bytes(50000)
# 150kb, slightly larger than can fit in a 16 bit int
large_msg = self.random_bytes(150000)
self.echo(small_msg)
self.echo(medium_msg)
self.echo(large_msg)
def test_default_builder(self):
"""
default builder should always generate valid frames
"""
msg = self.random_bytes()
assert websockets.Frame.default(msg, from_client=True)
assert websockets.Frame.default(msg, from_client=False)
def test_serialization_bijection(self):
"""
Ensure that various frame types can be serialized/deserialized back
and forth between to_bytes() and from_bytes()
"""
for is_client in [True, False]:
for num_bytes in [100, 50000, 150000]:
frame = websockets.Frame.default(
self.random_bytes(num_bytes), is_client
)
frame2 = websockets.Frame.from_bytes(
frame.to_bytes()
)
assert frame == frame2
bytes = b'\x81\x03cba'
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
def test_check_server_handshake(self):
headers = self.protocol.server_handshake_headers("key")
assert self.protocol.check_server_handshake(headers)
headers["Upgrade"] = "not_websocket"
assert not self.protocol.check_server_handshake(headers)
def test_check_client_handshake(self):
headers = self.protocol.client_handshake_headers("key")
assert self.protocol.check_client_handshake(headers) == "key"
headers["Upgrade"] = "not_websocket"
assert not self.protocol.check_client_handshake(headers)
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
client_hs = read_request(self.rfile)
self.protocol.check_client_handshake(client_hs.headers)
preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble.encode())
headers = self.protocol.server_handshake_headers(b"malformed key")
self.wfile.write(bytes(headers) + b"\r\n")
self.wfile.flush()
self.handshake_done = True
class TestBadHandshake(tservers.ServerTestBase):
"""
Ensure that the client disconnects if the server handshake is malformed
"""
handler = BadHandshakeHandler
def test(self):
with tutils.raises(exceptions.TcpDisconnect):
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message(b"hello")
class TestFrameHeader:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.FrameHeader(*args, **kwargs)
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f)))
assert f == f2
round()
round(fin=1)
round(rsv1=1)
round(rsv2=1)
round(rsv3=1)
round(payload_length=1)
round(payload_length=100)
round(payload_length=1000)
round(payload_length=10000)
round(opcode=websockets.OPCODE.PING)
round(masking_key=b"test")
def test_human_readable(self):
f = websockets.FrameHeader(
masking_key=b"test",
fin=True,
payload_length=10
)
assert repr(f)
f = websockets.FrameHeader()
assert repr(f)
def test_funky(self):
f = websockets.FrameHeader(masking_key=b"test", mask=False)
raw = bytes(f)
f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
assert not f2.mask
def test_violations(self):
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
def test_automask(self):
f = websockets.FrameHeader(mask=True)
assert f.masking_key
f = websockets.FrameHeader(masking_key=b"foob")
assert f.mask
f = websockets.FrameHeader(masking_key=b"foob", mask=0)
assert not f.mask
assert f.masking_key
class TestFrame:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
raw = bytes(f)
f2 = websockets.Frame.from_file(tutils.treader(raw))
assert f == f2
round(b"test")
round(b"test", fin=1)
round(b"test", rsv1=1)
round(b"test", opcode=websockets.OPCODE.PING)
round(b"test", masking_key=b"test")
def test_human_readable(self):
f = websockets.Frame()
assert repr(f)
def test_masker():
tests = [
[b"a"],
[b"four"],
[b"fourf"],
[b"fourfive"],
[b"a", b"aasdfasdfa", b"asdf"],
[b"a" * 50, b"aasdfasdfa", b"asdf"],
]
for i in tests:
m = websockets.Masker(b"abcd")
data = b"".join([m(t) for t in i])
data2 = websockets.Masker(b"abcd")(data)
assert data2 == b"".join(i)