websockets: refactor implementation and add tests

This commit is contained in:
Thomas Kriechbaumer 2016-08-18 17:31:43 +02:00
parent 281d779fa3
commit d12515f84b
9 changed files with 497 additions and 471 deletions

View File

@ -1,11 +1,37 @@
from __future__ import absolute_import, print_function, division
from .frame import FrameHeader, Frame, OPCODE
from .protocol import Masker, WebsocketsProtocol
from .frame import FrameHeader
from .frame import Frame
from .frame import OPCODE
from .frame import CLOSE_REASON
from .masker import Masker
from .utils import MAGIC
from .utils import VERSION
from .utils import client_handshake_headers
from .utils import server_handshake_headers
from .utils import check_handshake
from .utils import check_client_version
from .utils import create_server_nonce
from .utils import get_extensions
from .utils import get_protocol
from .utils import get_client_key
from .utils import get_server_accept
__all__ = [
"FrameHeader",
"Frame",
"Masker",
"WebsocketsProtocol",
"OPCODE",
"CLOSE_REASON",
"Masker",
"MAGIC",
"VERSION",
"client_handshake_headers",
"server_handshake_headers",
"check_handshake",
"check_client_version",
"create_server_nonce",
"get_extensions",
"get_protocol",
"get_client_key",
"get_server_accept",
]

View File

@ -2,7 +2,6 @@ from __future__ import absolute_import
import os
import struct
import io
import warnings
import six
@ -10,7 +9,7 @@ from netlib import tcp
from netlib import strutils
from netlib import utils
from netlib import human
from netlib.websockets import protocol
from .masker import Masker
MAX_16_BIT_INT = (1 << 16)
@ -18,6 +17,7 @@ MAX_64_BIT_INT = (1 << 64)
DEFAULT = object()
# RFC 6455, Section 5.2 - Base Framing Protocol
OPCODE = utils.BiDi(
CONTINUE=0x00,
TEXT=0x01,
@ -27,6 +27,23 @@ OPCODE = utils.BiDi(
PONG=0x0a
)
# RFC 6455, Section 7.4.1 - Defined Status Codes
CLOSE_REASON = utils.BiDi(
NORMAL_CLOSURE=1000,
GOING_AWAY=1001,
PROTOCOL_ERROR=1002,
UNSUPPORTED_DATA=1003,
RESERVED=1004,
RESERVED_NO_STATUS=1005,
RESERVED_ABNORMAL_CLOSURE=1006,
INVALID_PAYLOAD_DATA=1007,
POLICY_VIOLATION=1008,
MESSAGE_TOO_BIG=1009,
MANDATORY_EXTENSION=1010,
INTERNAL_ERROR=1011,
RESERVED_TLS_HANDHSAKE_FAILED=1015,
)
class FrameHeader(object):
@ -103,10 +120,6 @@ class FrameHeader(object):
vals.append(" %s" % human.pretty_size(self.payload_length))
return "".join(vals)
def human_readable(self):
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
return repr(self)
def __bytes__(self):
first_byte = utils.setbit(0, 7, self.fin)
first_byte = utils.setbit(first_byte, 6, self.rsv1)
@ -128,6 +141,9 @@ class FrameHeader(object):
# '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length
b += struct.pack('!Q', self.payload_length)
else:
raise ValueError("Payload length exceeds 64bit integer")
if self.masking_key:
b += self.masking_key
return b
@ -135,10 +151,6 @@ class FrameHeader(object):
if six.PY2:
__str__ = __bytes__
def to_bytes(self):
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
return bytes(self)
@classmethod
def from_file(cls, fp):
"""
@ -151,19 +163,17 @@ class FrameHeader(object):
rsv1 = utils.getbit(first_byte, 6)
rsv2 = utils.getbit(first_byte, 5)
rsv3 = utils.getbit(first_byte, 4)
# grab right-most 4 bits
opcode = first_byte & 15
opcode = first_byte & 0xF
mask_bit = utils.getbit(second_byte, 7)
# grab the next 7 bits
length_code = second_byte & 127
length_code = second_byte & 0x7F
# payload_lengthy > 125 indicates you need to read more bytes
# payload_length > 125 indicates you need to read more bytes
# to get the actual payload length
if length_code <= 125:
payload_length = length_code
elif length_code == 126:
payload_length, = struct.unpack("!H", fp.safe_read(2))
elif length_code == 127:
else: # length_code == 127:
payload_length, = struct.unpack("!Q", fp.safe_read(8))
# masking key only present if mask bit set
@ -191,11 +201,10 @@ class FrameHeader(object):
class Frame(object):
"""
Represents one websockets frame.
Constructor takes human readable forms of the frame components
from_bytes() is also avaliable.
Represents a single WebSockets frame.
Constructor takes human readable forms of the frame components.
from_bytes() reads from a file-like object to create a new Frame.
WebSockets Frame as defined in RFC6455
@ -223,27 +232,6 @@ class Frame(object):
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
self.header = FrameHeader(**kwargs)
@classmethod
def default(cls, message, from_client=False):
"""
Construct a basic websocket frame from some default values.
Creates a non-fragmented text frame.
"""
if from_client:
mask_bit = 1
masking_key = os.urandom(4)
else:
mask_bit = 0
masking_key = None
return cls(
message,
fin=1, # final frame
opcode=OPCODE.TEXT, # text
mask=mask_bit,
masking_key=masking_key,
)
@classmethod
def from_bytes(cls, bytestring):
"""
@ -258,17 +246,13 @@ class Frame(object):
ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload)
return ret
def human_readable(self):
warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
return repr(self)
def __bytes__(self):
"""
Serialize the frame to wire format. Returns a string.
"""
b = bytes(self.header)
if self.header.masking_key:
b += protocol.Masker(self.header.masking_key)(self.payload)
b += Masker(self.header.masking_key)(self.payload)
else:
b += self.payload
return b
@ -276,15 +260,6 @@ class Frame(object):
if six.PY2:
__str__ = __bytes__
def to_bytes(self):
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
return bytes(self)
def to_file(self, writer):
warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
writer.write(bytes(self))
writer.flush()
@classmethod
def from_file(cls, fp):
"""
@ -297,20 +272,11 @@ class Frame(object):
payload = fp.safe_read(header.payload_length)
if header.mask == 1 and header.masking_key:
payload = protocol.Masker(header.masking_key)(payload)
payload = Masker(header.masking_key)(payload)
return cls(
payload,
fin=header.fin,
opcode=header.opcode,
mask=header.mask,
payload_length=header.payload_length,
masking_key=header.masking_key,
rsv1=header.rsv1,
rsv2=header.rsv2,
rsv3=header.rsv3,
length_code=header.length_code
)
frame = cls(payload)
frame.header = header
return frame
def __eq__(self, other):
if isinstance(other, Frame):

View File

@ -0,0 +1,33 @@
from __future__ import absolute_import
import six
class Masker(object):
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns.
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
def __init__(self, key):
self.key = key
self.offset = 0
def mask(self, offset, data):
result = bytearray(data)
for i in range(len(data)):
if six.PY2:
result[i] ^= ord(self.key[offset % 4])
else:
result[i] ^= self.key[offset % 4]
offset += 1
result = bytes(result)
return result
def __call__(self, data):
ret = self.mask(self.offset, data)
self.offset += len(ret)
return ret

View File

@ -1,112 +0,0 @@
"""
Colleciton of utility functions that implement small portions of the RFC6455
WebSockets Protocol Useful for building WebSocket clients and servers.
Emphassis is on readabilty, simplicity and modularity, not performance or
completeness
This is a work in progress and does not yet contain all the utilites need to
create fully complient client/servers #
Spec: https://tools.ietf.org/html/rfc6455
The magic sha that websocket servers must know to prove they understand
RFC6455
"""
from __future__ import absolute_import
import base64
import hashlib
import os
import six
from netlib import http, strutils
websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
class Masker(object):
"""
Data sent from the server must be masked to prevent malicious clients
from sending data over the wire in predictable patterns
Servers do not have to mask data they send to the client.
https://tools.ietf.org/html/rfc6455#section-5.3
"""
def __init__(self, key):
self.key = key
self.offset = 0
def mask(self, offset, data):
result = bytearray(data)
if six.PY2:
for i in range(len(data)):
result[i] ^= ord(self.key[offset % 4])
offset += 1
result = str(result)
else:
for i in range(len(data)):
result[i] ^= self.key[offset % 4]
offset += 1
result = bytes(result)
return result
def __call__(self, data):
ret = self.mask(self.offset, data)
self.offset += len(ret)
return ret
class WebsocketsProtocol(object):
def __init__(self):
pass
@classmethod
def client_handshake_headers(self, key=None, version=VERSION):
"""
Create the headers for a valid HTTP upgrade request. If Key is not
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of http.Headers
"""
if not key:
key = base64.b64encode(os.urandom(16)).decode('ascii')
return http.Headers(
sec_websocket_key=key,
sec_websocket_version=version,
connection="Upgrade",
upgrade="websocket",
)
@classmethod
def server_handshake_headers(self, key):
"""
The server response is a valid HTTP 101 response.
"""
return http.Headers(
sec_websocket_accept=self.create_server_nonce(key),
connection="Upgrade",
upgrade="websocket"
)
@classmethod
def check_client_handshake(self, headers):
if headers.get("upgrade") != "websocket":
return
return headers.get("sec-websocket-key")
@classmethod
def check_server_handshake(self, headers):
if headers.get("upgrade") != "websocket":
return
return headers.get("sec-websocket-accept")
@classmethod
def create_server_nonce(self, client_nonce):
return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + websockets_magic).digest())

View File

@ -0,0 +1,90 @@
"""
Collection of WebSockets Protocol utility functions (RFC6455)
Spec: https://tools.ietf.org/html/rfc6455
"""
from __future__ import absolute_import
import base64
import hashlib
import os
from netlib import http, strutils
MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
"""
Create the headers for a valid HTTP upgrade request. If Key is not
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of http.Headers
"""
if version is None:
version = VERSION
if key is None:
key = base64.b64encode(os.urandom(16)).decode('ascii')
h = http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_version=version,
sec_websocket_key=key,
)
if protocol is not None:
h['sec-websocket-protocol'] = protocol
if extensions is not None:
h['sec-websocket-extensions'] = extensions
return h
def server_handshake_headers(client_key, protocol=None, extensions=None):
"""
The server response is a valid HTTP 101 response.
Returns an instance of http.Headers
"""
h = http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_accept=create_server_nonce(client_key),
)
if protocol is not None:
h['sec-websocket-protocol'] = protocol
if extensions is not None:
h['sec-websocket-extensions'] = extensions
return h
def check_handshake(headers):
return (
"upgrade" in headers.get("connection", "").lower() and
headers.get("upgrade", "").lower() == "websocket" and
(headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None)
)
def create_server_nonce(client_nonce):
return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest())
def check_client_version(headers):
return headers.get("sec-websocket-version", "") == VERSION
def get_extensions(headers):
return headers.get("sec-websocket-extensions", None)
def get_protocol(headers):
return headers.get("sec-websocket-protocol", None)
def get_client_key(headers):
return headers.get("sec-websocket-key", None)
def get_server_accept(headers):
return headers.get("sec-websocket-accept", None)

View File

@ -0,0 +1,164 @@
import os
import codecs
import pytest
from netlib import websockets
from netlib import tutils
class TestFrameHeader(object):
@pytest.mark.parametrize("input,expected", [
(0, '0100'),
(125, '017D'),
(126, '017E007E'),
(127, '017E007F'),
(142, '017E008E'),
(65534, '017EFFFE'),
(65535, '017EFFFF'),
(65536, '017F0000000000010000'),
(8589934591, '017F00000001FFFFFFFF'),
(2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
])
def test_serialization_length(self, input, expected):
h = websockets.FrameHeader(
opcode=websockets.OPCODE.TEXT,
payload_length=input,
)
assert bytes(h) == codecs.decode(expected, 'hex')
def test_serialization_too_large(self):
h = websockets.FrameHeader(
payload_length=2 ** 64 + 1,
)
with pytest.raises(ValueError):
bytes(h)
@pytest.mark.parametrize("input,expected", [
('0100', 0),
('017D', 125),
('017E007E', 126),
('017E007F', 127),
('017E008E', 142),
('017EFFFE', 65534),
('017EFFFF', 65535),
('017F0000000000010000', 65536),
('017F00000001FFFFFFFF', 8589934591),
('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
])
def test_deserialization_length(self, input, expected):
h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
assert h.payload_length == expected
@pytest.mark.parametrize("input,expected", [
('0100', (False, None)),
('018000000000', (True, '00000000')),
('018012345678', (True, '12345678')),
])
def test_deserialization_masking(self, input, expected):
h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
assert h.mask == expected[0]
if h.mask:
assert h.masking_key == codecs.decode(expected[1], 'hex')
def test_equality(self):
h = websockets.FrameHeader(mask=True, masking_key=b'1234')
h2 = websockets.FrameHeader(mask=True, masking_key=b'1234')
assert h == h2
h = websockets.FrameHeader(fin=True)
h2 = websockets.FrameHeader(fin=False)
assert h != h2
assert h != 'foobar'
def test_roundtrip(self):
def round(*args, **kwargs):
h = websockets.FrameHeader(*args, **kwargs)
h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h)))
assert h == h2
round()
round(fin=True)
round(rsv1=True)
round(rsv2=True)
round(rsv3=True)
round(payload_length=1)
round(payload_length=100)
round(payload_length=1000)
round(payload_length=10000)
round(opcode=websockets.OPCODE.PING)
round(masking_key=b"test")
def test_human_readable(self):
f = websockets.FrameHeader(
masking_key=b"test",
fin=True,
payload_length=10
)
assert repr(f)
f = websockets.FrameHeader()
assert repr(f)
def test_funky(self):
f = websockets.FrameHeader(masking_key=b"test", mask=False)
raw = bytes(f)
f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
assert not f2.mask
def test_violations(self):
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
def test_automask(self):
f = websockets.FrameHeader(mask=True)
assert f.masking_key
f = websockets.FrameHeader(masking_key=b"foob")
assert f.mask
f = websockets.FrameHeader(masking_key=b"foob", mask=0)
assert not f.mask
assert f.masking_key
class TestFrame(object):
def test_equality(self):
f = websockets.Frame(payload=b'1234')
f2 = websockets.Frame(payload=b'1234')
assert f == f2
assert f != b'1234'
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
raw = bytes(f)
f2 = websockets.Frame.from_file(tutils.treader(raw))
assert f == f2
round(b"test")
round(b"test", fin=1)
round(b"test", rsv1=1)
round(b"test", opcode=websockets.OPCODE.PING)
round(b"test", masking_key=b"test")
def test_human_readable(self):
f = websockets.Frame()
assert repr(f)
f = websockets.Frame(b"foobar")
assert "foobar" in repr(f)
@pytest.mark.parametrize("masked", [True, False])
@pytest.mark.parametrize("length", [100, 50000, 150000])
def test_serialization_bijection(self, masked, length):
frame = websockets.Frame(
os.urandom(length),
fin=True,
opcode=websockets.OPCODE.TEXT,
mask=int(masked),
masking_key=(os.urandom(4) if masked else None)
)
serialized = bytes(frame)
assert frame == websockets.Frame.from_bytes(serialized)

View File

@ -0,0 +1,23 @@
import codecs
import pytest
from netlib import websockets
class TestMasker(object):
@pytest.mark.parametrize("input,expected", [
([b"a"], '00'),
([b"four"], '070d1616'),
([b"fourf"], '070d161607'),
([b"fourfive"], '070d1616070b1501'),
([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'),
([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'), # noqa
])
def test_masker(self, input, expected):
m = websockets.Masker(b"abcd")
data = b"".join([m(t) for t in input])
assert data == codecs.decode(expected, 'hex')
data = websockets.Masker(b"abcd")(data)
assert data == b"".join(input)

View File

@ -0,0 +1,105 @@
import pytest
from netlib import http
from netlib import websockets
class TestUtils(object):
def test_client_handshake_headers(self):
h = websockets.client_handshake_headers(version='42')
assert h['sec-websocket-version'] == '42'
h = websockets.client_handshake_headers(key='some-key')
assert h['sec-websocket-key'] == 'some-key'
h = websockets.client_handshake_headers(protocol='foobar')
assert h['sec-websocket-protocol'] == 'foobar'
h = websockets.client_handshake_headers(extensions='foo; bar')
assert h['sec-websocket-extensions'] == 'foo; bar'
def test_server_handshake_headers(self):
h = websockets.server_handshake_headers('some-key')
assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
assert 'sec-websocket-protocol' not in h
assert 'sec-websocket-extensions' not in h
h = websockets.server_handshake_headers('some-key', 'foobar', 'foo; bar')
assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
assert h['sec-websocket-protocol'] == 'foobar'
assert h['sec-websocket-extensions'] == 'foo; bar'
@pytest.mark.parametrize("input,expected", [
([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], True),
([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'foobar')], True),
([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-KeY', b'foobar')], True),
([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-AccePt', b'foobar')], True),
([(b'connection', b'foo'), (b'upgrade', b'bar'), (b'sec-websocket-key', b'foobar')], False),
([(b'connection', b'upgrade'), (b'upgrade', b'websocket')], False),
([(b'connection', b'upgrade'), (b'sec-websocket-key', b'foobar')], False),
([(b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], False),
([], False),
])
def test_check_handshake(self, input, expected):
h = http.Headers(input)
assert websockets.check_handshake(h) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-version', b'13')], True),
([(b'Sec-WebSockeT-VerSion', b'13')], True),
([(b'sec-websocket-version', b'9')], False),
([(b'sec-websocket-version', b'42')], False),
([(b'sec-websocket-version', b'')], False),
([], False),
])
def test_check_client_version(self, input, expected):
h = http.Headers(input)
assert websockets.check_client_version(h) == expected
@pytest.mark.parametrize("input,expected", [
('foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
(b'foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
])
def test_create_server_nonce(self, input, expected):
assert websockets.create_server_nonce(input) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-extensions', b'foo; bar')], 'foo; bar'),
([(b'Sec-WebSockeT-ExteNsionS', b'foo; bar')], 'foo; bar'),
([(b'sec-websocket-extensions', b'')], ''),
([], None),
])
def test_get_extensions(self, input, expected):
h = http.Headers(input)
assert websockets.get_extensions(h) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-protocol', b'foobar')], 'foobar'),
([(b'Sec-WebSockeT-ProTocoL', b'foobar')], 'foobar'),
([(b'sec-websocket-protocol', b'')], ''),
([], None),
])
def test_get_protocol(self, input, expected):
h = http.Headers(input)
assert websockets.get_protocol(h) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-key', b'foobar')], 'foobar'),
([(b'Sec-WebSockeT-KeY', b'foobar')], 'foobar'),
([(b'sec-websocket-key', b'')], ''),
([], None),
])
def test_get_client_key(self, input, expected):
h = http.Headers(input)
assert websockets.get_client_key(h) == expected
@pytest.mark.parametrize("input,expected", [
([(b'sec-websocket-accept', b'foobar')], 'foobar'),
([(b'Sec-WebSockeT-AccepT', b'foobar')], 'foobar'),
([(b'sec-websocket-accept', b'')], ''),
([], None),
])
def test_get_server_accept(self, input, expected):
h = http.Headers(input)
assert websockets.get_server_accept(h) == expected

View File

@ -1,269 +0,0 @@
import os
from netlib.http.http1 import read_response, read_request
from netlib import tcp
from netlib import tutils
from netlib import websockets
from netlib.http import status_codes
from netlib.tutils import treq
from netlib import exceptions
from .. import tservers
class WebSocketsEchoHandler(tcp.BaseHandler):
def __init__(self, connection, address, server):
super(WebSocketsEchoHandler, self).__init__(
connection, address, server
)
self.protocol = websockets.WebsocketsProtocol()
self.handshake_done = False
def handle(self):
while True:
if not self.handshake_done:
self.handshake()
else:
self.read_next_message()
def read_next_message(self):
frame = websockets.Frame.from_file(self.rfile)
self.on_message(frame.payload)
def send_message(self, message):
frame = websockets.Frame.default(message, from_client=False)
frame.to_file(self.wfile)
def handshake(self):
req = read_request(self.rfile)
key = self.protocol.check_client_handshake(req.headers)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble.encode() + b"\r\n")
headers = self.protocol.server_handshake_headers(key)
self.wfile.write(str(headers) + "\r\n")
self.wfile.flush()
self.handshake_done = True
def on_message(self, message):
if message is not None:
self.send_message(message)
class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
self.protocol = websockets.WebsocketsProtocol()
self.client_nonce = None
def connect(self):
super(WebSocketsClient, self).connect()
preamble = b'GET / HTTP/1.1'
self.wfile.write(preamble + b"\r\n")
headers = self.protocol.client_handshake_headers()
self.client_nonce = headers["sec-websocket-key"].encode("ascii")
self.wfile.write(bytes(headers) + b"\r\n")
self.wfile.flush()
resp = read_response(self.rfile, treq(method=b"GET"))
server_nonce = self.protocol.check_server_handshake(resp.headers)
if not server_nonce == self.protocol.create_server_nonce(self.client_nonce):
self.close()
def read_next_message(self):
return websockets.Frame.from_file(self.rfile).payload
def send_message(self, message):
frame = websockets.Frame.default(message, from_client=True)
frame.to_file(self.wfile)
class TestWebSockets(tservers.ServerTestBase):
handler = WebSocketsEchoHandler
def __init__(self):
self.protocol = websockets.WebsocketsProtocol()
def random_bytes(self, n=100):
return os.urandom(n)
def echo(self, msg):
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message(msg)
response = client.read_next_message()
assert response == msg
def test_simple_echo(self):
self.echo(b"hello I'm the client")
def test_frame_sizes(self):
# length can fit in the the 7 bit payload length
small_msg = self.random_bytes(100)
# 50kb, sligthly larger than can fit in a 7 bit int
medium_msg = self.random_bytes(50000)
# 150kb, slightly larger than can fit in a 16 bit int
large_msg = self.random_bytes(150000)
self.echo(small_msg)
self.echo(medium_msg)
self.echo(large_msg)
def test_default_builder(self):
"""
default builder should always generate valid frames
"""
msg = self.random_bytes()
assert websockets.Frame.default(msg, from_client=True)
assert websockets.Frame.default(msg, from_client=False)
def test_serialization_bijection(self):
"""
Ensure that various frame types can be serialized/deserialized back
and forth between to_bytes() and from_bytes()
"""
for is_client in [True, False]:
for num_bytes in [100, 50000, 150000]:
frame = websockets.Frame.default(
self.random_bytes(num_bytes), is_client
)
frame2 = websockets.Frame.from_bytes(
frame.to_bytes()
)
assert frame == frame2
bytes = b'\x81\x03cba'
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
def test_check_server_handshake(self):
headers = self.protocol.server_handshake_headers("key")
assert self.protocol.check_server_handshake(headers)
headers["Upgrade"] = "not_websocket"
assert not self.protocol.check_server_handshake(headers)
def test_check_client_handshake(self):
headers = self.protocol.client_handshake_headers("key")
assert self.protocol.check_client_handshake(headers) == "key"
headers["Upgrade"] = "not_websocket"
assert not self.protocol.check_client_handshake(headers)
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
client_hs = read_request(self.rfile)
self.protocol.check_client_handshake(client_hs.headers)
preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble.encode())
headers = self.protocol.server_handshake_headers(b"malformed key")
self.wfile.write(bytes(headers) + b"\r\n")
self.wfile.flush()
self.handshake_done = True
class TestBadHandshake(tservers.ServerTestBase):
"""
Ensure that the client disconnects if the server handshake is malformed
"""
handler = BadHandshakeHandler
def test(self):
with tutils.raises(exceptions.TcpDisconnect):
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message(b"hello")
class TestFrameHeader:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.FrameHeader(*args, **kwargs)
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f)))
assert f == f2
round()
round(fin=1)
round(rsv1=1)
round(rsv2=1)
round(rsv3=1)
round(payload_length=1)
round(payload_length=100)
round(payload_length=1000)
round(payload_length=10000)
round(opcode=websockets.OPCODE.PING)
round(masking_key=b"test")
def test_human_readable(self):
f = websockets.FrameHeader(
masking_key=b"test",
fin=True,
payload_length=10
)
assert repr(f)
f = websockets.FrameHeader()
assert repr(f)
def test_funky(self):
f = websockets.FrameHeader(masking_key=b"test", mask=False)
raw = bytes(f)
f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
assert not f2.mask
def test_violations(self):
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
def test_automask(self):
f = websockets.FrameHeader(mask=True)
assert f.masking_key
f = websockets.FrameHeader(masking_key=b"foob")
assert f.mask
f = websockets.FrameHeader(masking_key=b"foob", mask=0)
assert not f.mask
assert f.masking_key
class TestFrame:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
raw = bytes(f)
f2 = websockets.Frame.from_file(tutils.treader(raw))
assert f == f2
round(b"test")
round(b"test", fin=1)
round(b"test", rsv1=1)
round(b"test", opcode=websockets.OPCODE.PING)
round(b"test", masking_key=b"test")
def test_human_readable(self):
f = websockets.Frame()
assert repr(f)
def test_masker():
tests = [
[b"a"],
[b"four"],
[b"fourf"],
[b"fourfive"],
[b"a", b"aasdfasdfa", b"asdf"],
[b"a" * 50, b"aasdfasdfa", b"asdf"],
]
for i in tests:
m = websockets.Masker(b"abcd")
data = b"".join([m(t) for t in i])
data2 = websockets.Masker(b"abcd")(data)
assert data2 == b"".join(i)