diff --git a/netlib/websockets/websockets.py b/netlib/websockets.py similarity index 99% rename from netlib/websockets/websockets.py rename to netlib/websockets.py index 86d98cafd..83e902385 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets.py @@ -8,7 +8,7 @@ import os import struct import io -from .. import utils +from . import utils # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py deleted file mode 100644 index 9b4faa337..000000000 --- a/netlib/websockets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from __future__ import (absolute_import, print_function, division) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py deleted file mode 100644 index 337c54964..000000000 --- a/netlib/websockets/implementations.py +++ /dev/null @@ -1,80 +0,0 @@ -from netlib import tcp -from base64 import b64encode -from StringIO import StringIO -from . import websockets as ws -import struct -import SocketServer -import os - -# Simple websocket client and servers that are used to exercise the functionality in websockets.py -# These are *not* fully RFC6455 compliant - -class WebSocketsEchoHandler(tcp.BaseHandler): - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__(connection, address, server) - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload - self.on_message(decoded) - - def send_message(self, message): - frame = ws.Frame.default(message, from_client = False) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() - - def handshake(self): - client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.process_handshake_from_client(client_hs) - response = ws.create_server_handshake(key) - self.wfile.write(response) - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.client_nounce = ws.create_client_nounce() - self.resource = "/" - - def connect(self): - super(WebSocketsClient, self).connect() - - handshake = ws.create_client_handshake( - self.address.host, - self.address.port, - self.client_nounce, - self.version, - self.resource - ) - - self.wfile.write(handshake) - self.wfile.flush() - - server_handshake = ws.read_handshake(self.rfile.read, 1) - - server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) - - if not server_nounce == ws.create_server_nounce(self.client_nounce): - self.close() - - def read_next_message(self): - return ws.Frame.from_byte_stream(self.rfile.read).payload - - def send_message(self, message): - frame = ws.Frame.default(message, from_client = True) - self.wfile.write(frame.safe_to_bytes()) - self.wfile.flush() diff --git a/test/test_websockets.py b/test/test_websockets.py index d17536383..62268423f 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,19 +1,92 @@ from netlib import tcp from netlib import test -from netlib.websockets import implementations as impl -from netlib.websockets import websockets as ws +from netlib import websockets import os from nose.tools import raises +class WebSocketsEchoHandler(tcp.BaseHandler): + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__( + connection, address, server + ) + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + decoded = websockets.Frame.from_byte_stream(self.rfile.read).decoded_payload + self.on_message(decoded) + + def send_message(self, message): + frame = websockets.Frame.default(message, from_client = False) + self.wfile.write(frame.safe_to_bytes()) + self.wfile.flush() + + def handshake(self): + client_hs = websockets.read_handshake(self.rfile.read, 1) + key = websockets.process_handshake_from_client(client_hs) + response = websockets.create_server_handshake(key) + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.version = "13" + self.client_nounce = websockets.create_client_nounce() + self.resource = "/" + + def connect(self): + super(WebSocketsClient, self).connect() + + handshake = websockets.create_client_handshake( + self.address.host, + self.address.port, + self.client_nounce, + self.version, + self.resource + ) + + self.wfile.write(handshake) + self.wfile.flush() + + server_handshake = websockets.read_handshake(self.rfile.read, 1) + server_nounce = websockets.process_handshake_from_server( + server_handshake, self.client_nounce + ) + + if not server_nounce == websockets.create_server_nounce(self.client_nounce): + self.close() + + def read_next_message(self): + return websockets.Frame.from_byte_stream(self.rfile.read).payload + + def send_message(self, message): + frame = websockets.Frame.default(message, from_client = True) + self.wfile.write(frame.safe_to_bytes()) + self.wfile.flush() + + class TestWebSockets(test.ServerTestBase): - handler = impl.WebSocketsEchoHandler + handler = WebSocketsEchoHandler def random_bytes(self, n = 100): return os.urandom(n) def echo(self, msg): - client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message(msg) response = client.read_next_message() @@ -39,10 +112,10 @@ class TestWebSockets(test.ServerTestBase): default builder should always generate valid frames """ msg = self.random_bytes() - client_frame = ws.Frame.default(msg, from_client = True) + client_frame = websockets.Frame.default(msg, from_client = True) assert client_frame.is_valid() - server_frame = ws.Frame.default(msg, from_client = False) + server_frame = websockets.Frame.default(msg, from_client = False) assert server_frame.is_valid() def test_serialization_bijection(self): @@ -52,26 +125,26 @@ class TestWebSockets(test.ServerTestBase): """ for is_client in [True, False]: for num_bytes in [100, 50000, 150000]: - frame = ws.Frame.default( + frame = websockets.Frame.default( self.random_bytes(num_bytes), is_client ) - assert frame == ws.Frame.from_bytes(frame.to_bytes()) + assert frame == websockets.Frame.from_bytes(frame.to_bytes()) bytes = b'\x81\x11cba' - assert ws.Frame.from_bytes(bytes).to_bytes() == bytes + assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - @raises(ws.WebSocketFrameValidationException) + @raises(websockets.WebSocketFrameValidationException) def test_safe_to_bytes(self): - frame = ws.Frame.default(self.random_bytes(8)) + frame = websockets.Frame.default(self.random_bytes(8)) frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() -class BadHandshakeHandler(impl.WebSocketsEchoHandler): +class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = ws.read_handshake(self.rfile.read, 1) - ws.process_handshake_from_client(client_hs) - response = ws.create_server_handshake("malformed_key") + client_hs = websockets.read_handshake(self.rfile.read, 1) + websockets.process_handshake_from_client(client_hs) + response = websockets.create_server_handshake("malformed_key") self.wfile.write(response) self.wfile.flush() self.handshake_done = True @@ -85,6 +158,6 @@ class TestBadHandshake(test.ServerTestBase): @raises(tcp.NetLibDisconnect) def test(self): - client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message("hello")