diff --git a/netlib/utils.py b/netlib/utils.py index 79077ac60..03a70977c 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -8,6 +8,9 @@ def isascii(s): return False return True +# best way to do it in python 2.x +def bytes_to_int(i): + return int(i.encode('hex'), 16) def cleanBin(s, fixspacing=False): """ diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 000000000..9b4faa337 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py new file mode 100644 index 000000000..1ded3b857 --- /dev/null +++ b/netlib/websockets/implementations.py @@ -0,0 +1,80 @@ +from netlib import tcp +from base64 import b64encode +from StringIO import StringIO +from . import websockets as ws +import struct +import SocketServer +import os + +# Simple websocket client and servers that are used to exercise the functionality in websockets.py +# These are *not* fully RFC6455 compliant + +class WebSocketsEchoHandler(tcp.BaseHandler): + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__(connection, address, server) + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + self.on_message(decoded) + + def send_message(self, message): + frame = ws.WebSocketsFrame.default(message, from_client = False) + self.wfile.write(frame.safe_to_bytes()) + self.wfile.flush() + + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + key = ws.process_handshake_from_client(client_hs) + response = ws.create_server_handshake(key) + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.version = "13" + self.client_nounce = ws.create_client_nounce() + self.resource = "/" + + def connect(self): + super(WebSocketsClient, self).connect() + + handshake = ws.create_client_handshake( + self.address.host, + self.address.port, + self.client_nounce, + self.version, + self.resource + ) + + self.wfile.write(handshake) + self.wfile.flush() + + server_handshake = ws.read_handshake(self.rfile.read, 1) + + server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) + + if not server_nounce == ws.create_server_nounce(self.client_nounce): + self.close() + + def read_next_message(self): + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + + def send_message(self, message): + frame = ws.WebSocketsFrame.default(message, from_client = True) + self.wfile.write(frame.safe_to_bytes()) + self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py new file mode 100644 index 000000000..ea3db21d5 --- /dev/null +++ b/netlib/websockets/websockets.py @@ -0,0 +1,389 @@ +from __future__ import absolute_import + +from base64 import b64encode +from hashlib import sha1 +from mimetools import Message +from netlib import tcp +from netlib import utils +from StringIO import StringIO +import os +import SocketServer +import struct +import io + +# Colleciton of utility functions that implement small portions of the RFC6455 WebSockets Protocol +# Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or completeness +# +# This is a work in progress and does not yet contain all the utilites need to create fully complient client/servers +# +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + +class WebSocketFrameValidationException(Exception): + pass + +class WebSocketsFrame(object): + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + def __init__( + self, + fin, # decmial integer 1 or 0 + opcode, # decmial integer 1 - 4 + mask_bit, # decimal integer 1 or 0 + payload_length_code, # decimal integer 1 - 127 + decoded_payload, # bytestring + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string + actual_payload_length = None, # any decimal integer + ): + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload + self.actual_payload_length = actual_payload_length + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use from_byte_stream() directly + """ + return cls.from_byte_stream(io.BytesIO(bytestring).read) + + + @classmethod + def default(cls, message, from_client = False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + length_code, actual_length = get_payload_length_pair(message) + + if from_client: + mask_bit = 1 + masking_key = random_masking_key() + payload = apply_mask(message, masking_key) + else: + mask_bit = 0 + masking_key = None + payload = message + + return cls( + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, + actual_payload_length = actual_length + ) + + def is_valid(self): + """ + Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame + has not been corrupted. + """ + try: + assert 0 <= self.fin <= 1 + assert 0 <= self.rsv1 <= 1 + assert 0 <= self.rsv2 <= 1 + assert 0 <= self.rsv3 <= 1 + assert 1 <= self.opcode <= 4 + assert 0 <= self.mask_bit <= 1 + assert 1 <= self.payload_length_code <= 127 + + if self.mask_bit == 1: + assert 1 <= len(self.masking_key) <= 4 + else: + assert self.masking_key == None + + assert self.actual_payload_length == len(self.payload) + + if self.payload is not None and self.masking_key is not None: + assert apply_mask(self.payload, self.masking_key) == self.decoded_payload + + return True + except AssertionError: + return False + + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length))]) + + def safe_to_bytes(self): + if self.is_valid(): + return self.to_bytes() + else: + raise WebSocketFrameValidationException() + + def to_bytes(self): + """ + Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees that the + serialized bytes will be correct. see safe_to_bytes() + """ + + max_16_bit_int = (1 << 16) + max_64_bit_int = (1 << 63) + + # break down of the bit-math used to construct the first byte from the frame's integer values + # first shift the significant bit into the correct position + # 00000001 << 7 = 10000000 + # ... + # then combine: + # + # 10000000 fin + # 01000000 res1 + # 00100000 res2 + # 00010000 res3 + # 00000001 opcode + # -------- OR + # 11110001 = first_byte + + first_byte = (self.fin << 7) | (self.rsv1 << 6) | (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + + second_byte = (self.mask_bit << 7) | self.payload_length_code + + bytes = chr(first_byte) + chr(second_byte) + + if self.actual_payload_length < 126: + pass + + elif self.actual_payload_length < max_16_bit_int: + + # '!H' pack as 16 bit unsigned short + bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length + + elif self.actual_payload_length < max_64_bit_int: + # '!Q' = pack as 64 bit unsigned long long + bytes += struct.pack('!Q', self.actual_payload_length) # add 8 bytes extended payload length + + if self.masking_key is not None: + bytes += self.masking_key + + bytes += self.payload # already will be encoded if neccessary + + return bytes + + + @classmethod + def from_byte_stream(cls, read_bytes): + """ + read a websockets frame sent by a server or client + + read_bytes is a function that can be backed + by sockets or by any byte reader. So this + function may be used to read frames from disk/wire/memory + """ + first_byte = utils.bytes_to_int(read_bytes(1)) + second_byte = utils.bytes_to_int(read_bytes(1)) + + fin = first_byte >> 7 # grab the left most bit + opcode = first_byte & 15 # grab right most 4 bits by and-ing with 00001111 + mask_bit = second_byte >> 7 # grab left most bit + payload_length = second_byte & 127 # grab the next 7 bits + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if payload_length <= 125: + actual_payload_length = payload_length + + elif payload_length == 126: + actual_payload_length = utils.bytes_to_int(read_bytes(2)) + + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = read_bytes(4) + else: + masking_key = None + + payload = read_bytes(actual_payload_length) + + if mask_bit == 1: + decoded_payload = apply_mask(payload, masking_key) + else: + decoded_payload = payload + + return cls( + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, + actual_payload_length = actual_payload_length + ) + + def __eq__(self, other): + return ( + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length) + +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + +def random_masking_key(): + return os.urandom(4) + +def create_client_handshake(host, port, key, version, resource): + """ + WebSockets connections are intiated by the client with a valid HTTP upgrade request + """ + headers = [ + ('Host', '%s:%s' % (host, port)), + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ] + request = "GET %s HTTP/1.1" % resource + return build_handshake(headers, request) + +def create_server_handshake(key): + """ + The server response is a valid HTTP 101 response. + """ + headers = [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', create_server_nounce(key)) + ] + request = "HTTP/1.1 101 Switching Protocols" + return build_handshake(headers, request) + + +def build_handshake(headers, request): + handshake = [request.encode('utf-8')] + for header, value in headers: + handshake.append(("%s: %s" % (header, value)).encode('utf-8')) + handshake.append(b'\r\n') + return b'\r\n'.join(handshake) + +def read_handshake(read_bytes, num_bytes_per_read): + """ + From provided function that reads bytes, read in a + complete HTTP request, which terminates with a CLRF + """ + response = b'' + doubleCLRF = b'\r\n\r\n' + while True: + bytes = read_bytes(num_bytes_per_read) + if not bytes: + break + response += bytes + if doubleCLRF in response: + break + return response + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is larger + than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + +def process_handshake_from_client(handshake): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Key'] + return key + +def process_handshake_from_server(handshake, client_nounce): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Accept'] + return key + +def headers_from_http_message(http_message): + return Message(StringIO(http_message.split('\r\n', 1)[1])) + +def create_server_nounce(client_nounce): + return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + +def create_client_nounce(): + return b64encode(os.urandom(16)).decode('utf-8') + diff --git a/test/test_websockets.py b/test/test_websockets.py new file mode 100644 index 000000000..951aa41ff --- /dev/null +++ b/test/test_websockets.py @@ -0,0 +1,83 @@ +from netlib import tcp +from netlib import test +from netlib.websockets import implementations as impl +from netlib.websockets import websockets as ws +import os +from nose.tools import raises + +class TestWebSockets(test.ServerTestBase): + handler = impl.WebSocketsEchoHandler + + def random_bytes(self, n = 100): + return os.urandom(n) + + def echo(self, msg): + client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message(msg) + response = client.read_next_message() + assert response == msg + + def test_simple_echo(self): + self.echo("hello I'm the client") + + def test_frame_sizes(self): + small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length + medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int + large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int + + self.echo(small_msg) + self.echo(medium_msg) + self.echo(large_msg) + + def test_default_builder(self): + """ + default builder should always generate valid frames + """ + msg = self.random_bytes() + client_frame = ws.WebSocketsFrame.default(msg, from_client = True) + assert client_frame.is_valid() + + server_frame = ws.WebSocketsFrame.default(msg, from_client = False) + assert server_frame.is_valid() + + def test_serialization_bijection(self): + """ + Ensure that various frame types can be serialized/deserialized back and forth + between to_bytes() and from_bytes() + """ + for is_client in [True, False]: + for num_bytes in [100, 50000, 150000]: + frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client) + assert frame == ws.WebSocketsFrame.from_bytes(frame.to_bytes()) + + bytes = b'\x81\x11cba' + assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes + + @raises(ws.WebSocketFrameValidationException) + def test_safe_to_bytes(self): + frame = ws.WebSocketsFrame.default(self.random_bytes(8)) + frame.actual_payload_length = 1 #corrupt the frame + frame.safe_to_bytes() + + +class BadHandshakeHandler(impl.WebSocketsEchoHandler): + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + key = ws.process_handshake_from_client(client_hs) + response = ws.create_server_handshake("malformed_key") + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + +class TestBadHandshake(test.ServerTestBase): + """ + Ensure that the client disconnects if the server handshake is malformed + """ + handler = BadHandshakeHandler + + @raises(tcp.NetLibDisconnect) + def test(self): + client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message("hello") \ No newline at end of file