Websockets: reorganise

- websockets.py to top-level
- implementations into test suite
This commit is contained in:
Aldo Cortesi 2015-04-20 09:38:09 +12:00
parent 08ba987a84
commit 74389ef04a
4 changed files with 90 additions and 98 deletions

View File

@ -8,7 +8,7 @@ import os
import struct
import io
from .. import utils
from . import utils
# Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers.

View File

@ -1 +0,0 @@
from __future__ import (absolute_import, print_function, division)

View File

@ -1,80 +0,0 @@
from netlib import tcp
from base64 import b64encode
from StringIO import StringIO
from . import websockets as ws
import struct
import SocketServer
import os
# Simple websocket client and servers that are used to exercise the functionality in websockets.py
# These are *not* fully RFC6455 compliant
class WebSocketsEchoHandler(tcp.BaseHandler):
def __init__(self, connection, address, server):
super(WebSocketsEchoHandler, self).__init__(connection, address, server)
self.handshake_done = False
def handle(self):
while True:
if not self.handshake_done:
self.handshake()
else:
self.read_next_message()
def read_next_message(self):
decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload
self.on_message(decoded)
def send_message(self, message):
frame = ws.Frame.default(message, from_client = False)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()
def handshake(self):
client_hs = ws.read_handshake(self.rfile.read, 1)
key = ws.process_handshake_from_client(client_hs)
response = ws.create_server_handshake(key)
self.wfile.write(response)
self.wfile.flush()
self.handshake_done = True
def on_message(self, message):
if message is not None:
self.send_message(message)
class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
self.version = "13"
self.client_nounce = ws.create_client_nounce()
self.resource = "/"
def connect(self):
super(WebSocketsClient, self).connect()
handshake = ws.create_client_handshake(
self.address.host,
self.address.port,
self.client_nounce,
self.version,
self.resource
)
self.wfile.write(handshake)
self.wfile.flush()
server_handshake = ws.read_handshake(self.rfile.read, 1)
server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce)
if not server_nounce == ws.create_server_nounce(self.client_nounce):
self.close()
def read_next_message(self):
return ws.Frame.from_byte_stream(self.rfile.read).payload
def send_message(self, message):
frame = ws.Frame.default(message, from_client = True)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()

View File

@ -1,19 +1,92 @@
from netlib import tcp
from netlib import test
from netlib.websockets import implementations as impl
from netlib.websockets import websockets as ws
from netlib import websockets
import os
from nose.tools import raises
class WebSocketsEchoHandler(tcp.BaseHandler):
def __init__(self, connection, address, server):
super(WebSocketsEchoHandler, self).__init__(
connection, address, server
)
self.handshake_done = False
def handle(self):
while True:
if not self.handshake_done:
self.handshake()
else:
self.read_next_message()
def read_next_message(self):
decoded = websockets.Frame.from_byte_stream(self.rfile.read).decoded_payload
self.on_message(decoded)
def send_message(self, message):
frame = websockets.Frame.default(message, from_client = False)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()
def handshake(self):
client_hs = websockets.read_handshake(self.rfile.read, 1)
key = websockets.process_handshake_from_client(client_hs)
response = websockets.create_server_handshake(key)
self.wfile.write(response)
self.wfile.flush()
self.handshake_done = True
def on_message(self, message):
if message is not None:
self.send_message(message)
class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
self.version = "13"
self.client_nounce = websockets.create_client_nounce()
self.resource = "/"
def connect(self):
super(WebSocketsClient, self).connect()
handshake = websockets.create_client_handshake(
self.address.host,
self.address.port,
self.client_nounce,
self.version,
self.resource
)
self.wfile.write(handshake)
self.wfile.flush()
server_handshake = websockets.read_handshake(self.rfile.read, 1)
server_nounce = websockets.process_handshake_from_server(
server_handshake, self.client_nounce
)
if not server_nounce == websockets.create_server_nounce(self.client_nounce):
self.close()
def read_next_message(self):
return websockets.Frame.from_byte_stream(self.rfile.read).payload
def send_message(self, message):
frame = websockets.Frame.default(message, from_client = True)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()
class TestWebSockets(test.ServerTestBase):
handler = impl.WebSocketsEchoHandler
handler = WebSocketsEchoHandler
def random_bytes(self, n = 100):
return os.urandom(n)
def echo(self, msg):
client = impl.WebSocketsClient(("127.0.0.1", self.port))
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message(msg)
response = client.read_next_message()
@ -39,10 +112,10 @@ class TestWebSockets(test.ServerTestBase):
default builder should always generate valid frames
"""
msg = self.random_bytes()
client_frame = ws.Frame.default(msg, from_client = True)
client_frame = websockets.Frame.default(msg, from_client = True)
assert client_frame.is_valid()
server_frame = ws.Frame.default(msg, from_client = False)
server_frame = websockets.Frame.default(msg, from_client = False)
assert server_frame.is_valid()
def test_serialization_bijection(self):
@ -52,26 +125,26 @@ class TestWebSockets(test.ServerTestBase):
"""
for is_client in [True, False]:
for num_bytes in [100, 50000, 150000]:
frame = ws.Frame.default(
frame = websockets.Frame.default(
self.random_bytes(num_bytes), is_client
)
assert frame == ws.Frame.from_bytes(frame.to_bytes())
assert frame == websockets.Frame.from_bytes(frame.to_bytes())
bytes = b'\x81\x11cba'
assert ws.Frame.from_bytes(bytes).to_bytes() == bytes
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
@raises(ws.WebSocketFrameValidationException)
@raises(websockets.WebSocketFrameValidationException)
def test_safe_to_bytes(self):
frame = ws.Frame.default(self.random_bytes(8))
frame = websockets.Frame.default(self.random_bytes(8))
frame.actual_payload_length = 1 # corrupt the frame
frame.safe_to_bytes()
class BadHandshakeHandler(impl.WebSocketsEchoHandler):
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
client_hs = ws.read_handshake(self.rfile.read, 1)
ws.process_handshake_from_client(client_hs)
response = ws.create_server_handshake("malformed_key")
client_hs = websockets.read_handshake(self.rfile.read, 1)
websockets.process_handshake_from_client(client_hs)
response = websockets.create_server_handshake("malformed_key")
self.wfile.write(response)
self.wfile.flush()
self.handshake_done = True
@ -85,6 +158,6 @@ class TestBadHandshake(test.ServerTestBase):
@raises(tcp.NetLibDisconnect)
def test(self):
client = impl.WebSocketsClient(("127.0.0.1", self.port))
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message("hello")