Websockets: reorganise

- websockets.py to top-level
- implementations into test suite
This commit is contained in:
Aldo Cortesi 2015-04-20 09:38:09 +12:00
parent 08ba987a84
commit 74389ef04a
4 changed files with 90 additions and 98 deletions

View File

@ -8,7 +8,7 @@ import os
import struct import struct
import io import io
from .. import utils from . import utils
# Colleciton of utility functions that implement small portions of the RFC6455 # Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers. # WebSockets Protocol Useful for building WebSocket clients and servers.

View File

@ -1 +0,0 @@
from __future__ import (absolute_import, print_function, division)

View File

@ -1,80 +0,0 @@
from netlib import tcp
from base64 import b64encode
from StringIO import StringIO
from . import websockets as ws
import struct
import SocketServer
import os
# Simple websocket client and servers that are used to exercise the functionality in websockets.py
# These are *not* fully RFC6455 compliant
class WebSocketsEchoHandler(tcp.BaseHandler):
def __init__(self, connection, address, server):
super(WebSocketsEchoHandler, self).__init__(connection, address, server)
self.handshake_done = False
def handle(self):
while True:
if not self.handshake_done:
self.handshake()
else:
self.read_next_message()
def read_next_message(self):
decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload
self.on_message(decoded)
def send_message(self, message):
frame = ws.Frame.default(message, from_client = False)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()
def handshake(self):
client_hs = ws.read_handshake(self.rfile.read, 1)
key = ws.process_handshake_from_client(client_hs)
response = ws.create_server_handshake(key)
self.wfile.write(response)
self.wfile.flush()
self.handshake_done = True
def on_message(self, message):
if message is not None:
self.send_message(message)
class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
self.version = "13"
self.client_nounce = ws.create_client_nounce()
self.resource = "/"
def connect(self):
super(WebSocketsClient, self).connect()
handshake = ws.create_client_handshake(
self.address.host,
self.address.port,
self.client_nounce,
self.version,
self.resource
)
self.wfile.write(handshake)
self.wfile.flush()
server_handshake = ws.read_handshake(self.rfile.read, 1)
server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce)
if not server_nounce == ws.create_server_nounce(self.client_nounce):
self.close()
def read_next_message(self):
return ws.Frame.from_byte_stream(self.rfile.read).payload
def send_message(self, message):
frame = ws.Frame.default(message, from_client = True)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()

View File

@ -1,19 +1,92 @@
from netlib import tcp from netlib import tcp
from netlib import test from netlib import test
from netlib.websockets import implementations as impl from netlib import websockets
from netlib.websockets import websockets as ws
import os import os
from nose.tools import raises from nose.tools import raises
class WebSocketsEchoHandler(tcp.BaseHandler):
def __init__(self, connection, address, server):
super(WebSocketsEchoHandler, self).__init__(
connection, address, server
)
self.handshake_done = False
def handle(self):
while True:
if not self.handshake_done:
self.handshake()
else:
self.read_next_message()
def read_next_message(self):
decoded = websockets.Frame.from_byte_stream(self.rfile.read).decoded_payload
self.on_message(decoded)
def send_message(self, message):
frame = websockets.Frame.default(message, from_client = False)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()
def handshake(self):
client_hs = websockets.read_handshake(self.rfile.read, 1)
key = websockets.process_handshake_from_client(client_hs)
response = websockets.create_server_handshake(key)
self.wfile.write(response)
self.wfile.flush()
self.handshake_done = True
def on_message(self, message):
if message is not None:
self.send_message(message)
class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
self.version = "13"
self.client_nounce = websockets.create_client_nounce()
self.resource = "/"
def connect(self):
super(WebSocketsClient, self).connect()
handshake = websockets.create_client_handshake(
self.address.host,
self.address.port,
self.client_nounce,
self.version,
self.resource
)
self.wfile.write(handshake)
self.wfile.flush()
server_handshake = websockets.read_handshake(self.rfile.read, 1)
server_nounce = websockets.process_handshake_from_server(
server_handshake, self.client_nounce
)
if not server_nounce == websockets.create_server_nounce(self.client_nounce):
self.close()
def read_next_message(self):
return websockets.Frame.from_byte_stream(self.rfile.read).payload
def send_message(self, message):
frame = websockets.Frame.default(message, from_client = True)
self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()
class TestWebSockets(test.ServerTestBase): class TestWebSockets(test.ServerTestBase):
handler = impl.WebSocketsEchoHandler handler = WebSocketsEchoHandler
def random_bytes(self, n = 100): def random_bytes(self, n = 100):
return os.urandom(n) return os.urandom(n)
def echo(self, msg): def echo(self, msg):
client = impl.WebSocketsClient(("127.0.0.1", self.port)) client = WebSocketsClient(("127.0.0.1", self.port))
client.connect() client.connect()
client.send_message(msg) client.send_message(msg)
response = client.read_next_message() response = client.read_next_message()
@ -39,10 +112,10 @@ class TestWebSockets(test.ServerTestBase):
default builder should always generate valid frames default builder should always generate valid frames
""" """
msg = self.random_bytes() msg = self.random_bytes()
client_frame = ws.Frame.default(msg, from_client = True) client_frame = websockets.Frame.default(msg, from_client = True)
assert client_frame.is_valid() assert client_frame.is_valid()
server_frame = ws.Frame.default(msg, from_client = False) server_frame = websockets.Frame.default(msg, from_client = False)
assert server_frame.is_valid() assert server_frame.is_valid()
def test_serialization_bijection(self): def test_serialization_bijection(self):
@ -52,26 +125,26 @@ class TestWebSockets(test.ServerTestBase):
""" """
for is_client in [True, False]: for is_client in [True, False]:
for num_bytes in [100, 50000, 150000]: for num_bytes in [100, 50000, 150000]:
frame = ws.Frame.default( frame = websockets.Frame.default(
self.random_bytes(num_bytes), is_client self.random_bytes(num_bytes), is_client
) )
assert frame == ws.Frame.from_bytes(frame.to_bytes()) assert frame == websockets.Frame.from_bytes(frame.to_bytes())
bytes = b'\x81\x11cba' bytes = b'\x81\x11cba'
assert ws.Frame.from_bytes(bytes).to_bytes() == bytes assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
@raises(ws.WebSocketFrameValidationException) @raises(websockets.WebSocketFrameValidationException)
def test_safe_to_bytes(self): def test_safe_to_bytes(self):
frame = ws.Frame.default(self.random_bytes(8)) frame = websockets.Frame.default(self.random_bytes(8))
frame.actual_payload_length = 1 # corrupt the frame frame.actual_payload_length = 1 # corrupt the frame
frame.safe_to_bytes() frame.safe_to_bytes()
class BadHandshakeHandler(impl.WebSocketsEchoHandler): class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self): def handshake(self):
client_hs = ws.read_handshake(self.rfile.read, 1) client_hs = websockets.read_handshake(self.rfile.read, 1)
ws.process_handshake_from_client(client_hs) websockets.process_handshake_from_client(client_hs)
response = ws.create_server_handshake("malformed_key") response = websockets.create_server_handshake("malformed_key")
self.wfile.write(response) self.wfile.write(response)
self.wfile.flush() self.wfile.flush()
self.handshake_done = True self.handshake_done = True
@ -85,6 +158,6 @@ class TestBadHandshake(test.ServerTestBase):
@raises(tcp.NetLibDisconnect) @raises(tcp.NetLibDisconnect)
def test(self): def test(self):
client = impl.WebSocketsClient(("127.0.0.1", self.port)) client = WebSocketsClient(("127.0.0.1", self.port))
client.connect() client.connect()
client.send_message("hello") client.send_message("hello")