From e41e5cbfdd7b778e6f68e86658e95f9e413133cb Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Thu, 9 Apr 2015 19:35:40 -0700 Subject: [PATCH 01/15] netlib websockets --- netlib/http.py | 14 + netlib/utils.py | 3 + netlib/websockets/__init__.py | 1 + netlib/websockets/implementations.py | 81 ++++++ netlib/websockets/websockets.py | 368 +++++++++++++++++++++++++++ test/test_websockets.py | 15 ++ 6 files changed, 482 insertions(+) create mode 100644 netlib/websockets/__init__.py create mode 100644 netlib/websockets/implementations.py create mode 100644 netlib/websockets/websockets.py create mode 100644 test/test_websockets.py diff --git a/netlib/http.py b/netlib/http.py index 264388636..2c72621dd 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -29,6 +29,20 @@ def _is_valid_host(host): return None return True +def is_successful_upgrade(request, response): + """ + determines if a client and server successfully agreed to an HTTP protocol upgrade + + https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism + """ + http_switching_protocols_code = 101 + + if request and response: + responseUpgrade = request.headers.get("Upgrade") + requestUpgrade = response.headers.get("Upgrade") + if response.code == http_switching_protocols_code and responseUpgrade == requestUpgrade: + return requestUpgrade[0] if len(requestUpgrade) > 0 else None + return None def parse_url(url): """ diff --git a/netlib/utils.py b/netlib/utils.py index 79077ac60..03a70977c 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -8,6 +8,9 @@ def isascii(s): return False return True +# best way to do it in python 2.x +def bytes_to_int(i): + return int(i.encode('hex'), 16) def cleanBin(s, fixspacing=False): """ diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 000000000..9b4faa337 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py new file mode 100644 index 000000000..78ae5be6b --- /dev/null +++ b/netlib/websockets/implementations.py @@ -0,0 +1,81 @@ +from netlib import tcp +from base64 import b64encode +from StringIO import StringIO +from . import websockets as ws +import struct +import SocketServer +import os + +# Simple websocket client and servers that are used to exercise the functionality in websockets.py +# These are *not* fully RFC6455 compliant + +class WebSocketsEchoHandler(tcp.BaseHandler): + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__(connection, address, server) + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + self.on_message(decoded) + + def send_message(self, message): + frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False) + self.wfile.write(frame.to_bytes()) + self.wfile.flush() + + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + key = ws.server_process_handshake(client_hs) + response = ws.create_server_handshake(key) + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.version = "13" + self.key = b64encode(os.urandom(16)).decode('utf-8') + self.resource = "/" + + def connect(self): + super(WebSocketsClient, self).connect() + + handshake = ws.create_client_handshake( + self.address.host, + self.address.port, + self.key, + self.version, + self.resource + ) + + self.wfile.write(handshake) + self.wfile.flush() + + response = ws.read_handshake(self.rfile.read, 1) + + if not response: + self.close() + + def read_next_message(self): + try: + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + except IndexError: + self.close() + + def send_message(self, message): + frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True) + self.wfile.write(frame.to_bytes()) + self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py new file mode 100644 index 000000000..b796ce399 --- /dev/null +++ b/netlib/websockets/websockets.py @@ -0,0 +1,368 @@ +from __future__ import absolute_import + +from base64 import b64encode +from hashlib import sha1 +from mimetools import Message +from netlib import tcp +from netlib import utils +from StringIO import StringIO +import os +import SocketServer +import struct +import io + +# Colleciton of utility functions that implement small portions of the RFC6455 WebSockets Protocol +# Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or completeness +# +# This is a work in progress and does not yet contain all the utilites need to create fully complient client/servers +# +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + +class WebSocketFrameValidationException(Exception): + pass + +class WebSocketsFrame(object): + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + def __init__( + self, + fin, # decmial integer 1 or 0 + opcode, # decmial integer 1 - 4 + mask_bit, # decimal integer 1 or 0 + payload_length_code, # decimal integer 1 - 127 + decoded_payload, # bytestring + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string + actual_payload_length = None, # any decimal integer + use_validation = True # indicates whether or not you care if this frame adheres to the spec + ): + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload + self.actual_payload_length = actual_payload_length + self.use_validation = use_validation + + if self.use_validation: + self.validate_frame() + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use read_frame() directly + """ + self.from_byte_stream(io.BytesIO(bytestring).read) + + @classmethod + def default_frame_from_message(cls, message, from_client = False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + length_code, actual_length = get_payload_length_pair(message) + + if from_client: + mask_bit = 1 + masking_key = random_masking_key() + payload = apply_mask(message, masking_key) + else: + mask_bit = 0 + masking_key = None + payload = message + + return cls( + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, + actual_payload_length = actual_length + ) + + def validate_frame(self): + """ + Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame + has not been corrupted. + """ + try: + assert 0 <= self.fin <= 1 + assert 0 <= self.rsv1 <= 1 + assert 0 <= self.rsv2 <= 1 + assert 0 <= self.rsv3 <= 1 + assert 1 <= self.opcode <= 4 + assert 0 <= self.mask_bit <= 1 + assert 1 <= self.payload_length_code <= 127 + + if self.mask_bit == 1: + assert 1 <= len(self.masking_key) <= 4 + else: + assert self.masking_key == None + + assert self.actual_payload_length == len(self.payload) + + if self.payload is not None and self.masking_key is not None: + apply_mask(self.payload, self.masking_key) == self.decoded_payload + + except AssertionError: + raise WebSocketFrameValidationException() + + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)), + ("use_validation - " + str(self.use_validation))]) + + def to_bytes(self): + """ + Serialize the frame back into the wire format, returns a bytestring + """ + # validate enforces all the assumptions made by this serializer + # in the spritit of mitmproxy, it's possible to create and serialize invalid frames + # by skipping validation. + if self.use_validation: + self.validate_frame() + + max_16_bit_int = (1 << 16) + max_64_bit_int = (1 << 63) + + # break down of the bit-math used to construct the first byte from the frame's integer values + # first shift the significant bit into the correct position + # 00000001 << 7 = 10000000 + # ... + # then combine: + # + # 10000000 fin + # 01000000 res1 + # 00100000 res2 + # 00010000 res3 + # 00000001 opcode + # -------- OR + # 11110001 = first_byte + + first_byte = (self.fin << 7) | (self.rsv1 << 6) | (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + + second_byte = (self.mask_bit << 7) | self.payload_length_code + + bytes = chr(first_byte) + chr(second_byte) + + if self.actual_payload_length < 126: + pass + + elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short + bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length + + elif self.actual_payload_length < max_64_bit_int: + # '!Q' = pack as 64 bit unsigned long long + bytes += struct.pack('!Q', self.actual_payload_length) # add 8 bytes extended payload length + + if self.masking_key is not None: + bytes += self.masking_key + + bytes += self.payload # already will be encoded if neccessary + + return bytes + + + @classmethod + def from_byte_stream(cls, read_bytes): + """ + read a websockets frame sent by a server or client + + read_bytes is a function that can be backed + by sockets or by any byte reader. So this + function may be used to read frames from disk/wire/memory + """ + first_byte = utils.bytes_to_int(read_bytes(1)) + second_byte = utils.bytes_to_int(read_bytes(1)) + + fin = first_byte >> 7 # grab the left most bit + opcode = first_byte & 15 # grab right most 4 bits by and-ing with 00001111 + mask_bit = second_byte >> 7 # grab left most bit + payload_length = second_byte & 127 # grab the next 7 bits + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if payload_length <= 125: + actual_payload_length = payload_length + + elif payload_length == 126: + actual_payload_length = utils.bytes_to_int(read_bytes(2)) + + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = read_bytes(4) + else: + masking_key = None + + payload = read_bytes(actual_payload_length) + + if mask_bit == 1: + decoded_payload = apply_mask(payload, masking_key) + else: + decoded_payload = payload + + return cls( + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, + actual_payload_length = actual_payload_length + ) + +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + +def random_masking_key(): + return os.urandom(4) + +def masking_key_list(masking_key): + return [utils.bytes_to_int(byte) for byte in masking_key] + +def create_client_handshake(host, port, key, version, resource): + """ + WebSockets connections are intiated by the client with a valid HTTP upgrade request + """ + headers = [ + ('Host', '%s:%s' % (host, port)), + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ] + request = "GET %s HTTP/1.1" % resource + return build_handshake(headers, request) + + +def create_server_handshake(key, magic = websockets_magic): + """ + The server response is a valid HTTP 101 response. + """ + digest = b64encode(sha1(key + magic).hexdigest().decode('hex')) + headers = [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', digest) + ] + request = "HTTP/1.1 101 Switching Protocols" + return build_handshake(headers, request) + + +def build_handshake(headers, request): + handshake = [request.encode('utf-8')] + for header, value in headers: + handshake.append(("%s: %s" % (header, value)).encode('utf-8')) + handshake.append(b'\r\n') + return b'\r\n'.join(handshake) + + +def read_handshake(read_bytes, num_bytes_per_read): + """ + From provided function that reads bytes, read in a + complete HTTP request, which terminates with a CLRF + """ + response = b'' + doubleCLRF = b'\r\n\r\n' + while True: + bytes = read_bytes(num_bytes_per_read) + if not bytes: + break + response += bytes + if doubleCLRF in response: + break + return response + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is larger + than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + +def server_process_handshake(handshake): + headers = Message(StringIO(handshake.split('\r\n', 1)[1])) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Key'] + return key + +def generate_client_nounce(): + return b64encode(os.urandom(16)).decode('utf-8') + diff --git a/test/test_websockets.py b/test/test_websockets.py new file mode 100644 index 000000000..d7e1627fa --- /dev/null +++ b/test/test_websockets.py @@ -0,0 +1,15 @@ +from netlib import test +from netlib.websockets import implementations as ws + +class TestWebSockets(test.ServerTestBase): + handler = ws.WebSocketsEchoHandler + + def test_websockets_echo(self): + msg = "hello I'm the client" + client = ws.WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message(msg) + response = client.read_next_message() + print "Assert response: " + response + " == msg: " + msg + assert response == msg + From 0edc04814e3affa71025938ac354707b9b4c481c Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 11:35:15 -0700 Subject: [PATCH 02/15] small cleanups, working on tests --- netlib/websockets/implementations.py | 10 ++++---- netlib/websockets/websockets.py | 35 ++++++++++++++-------------- test/test_websockets.py | 24 +++++++++++++++---- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 78ae5be6b..ff42ff658 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -26,8 +26,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.on_message(decoded) def send_message(self, message): - frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False) - self.wfile.write(frame.to_bytes()) + frame = ws.WebSocketsFrame.default(message, from_client = False) + self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() def handshake(self): @@ -47,7 +47,7 @@ class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) self.version = "13" - self.key = b64encode(os.urandom(16)).decode('utf-8') + self.key = ws.generate_client_nounce() self.resource = "/" def connect(self): @@ -76,6 +76,6 @@ class WebSocketsClient(tcp.TCPClient): self.close() def send_message(self, message): - frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True) - self.wfile.write(frame.to_bytes()) + frame = ws.WebSocketsFrame.default(message, from_client = True) + self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index b796ce399..527d55d62 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -65,7 +65,6 @@ class WebSocketsFrame(object): payload = None, # bytestring masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer - use_validation = True # indicates whether or not you care if this frame adheres to the spec ): self.fin = fin self.rsv1 = rsv1 @@ -78,21 +77,18 @@ class WebSocketsFrame(object): self.payload = payload self.decoded_payload = decoded_payload self.actual_payload_length = actual_payload_length - self.use_validation = use_validation - - if self.use_validation: - self.validate_frame() @classmethod def from_bytes(cls, bytestring): """ Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use read_frame() directly + to construct a frame from a stream of bytes, use from_byte_stream() directly """ self.from_byte_stream(io.BytesIO(bytestring).read) + @classmethod - def default_frame_from_message(cls, message, from_client = False): + def default(cls, message, from_client = False): """ Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. @@ -119,7 +115,7 @@ class WebSocketsFrame(object): actual_payload_length = actual_length ) - def validate_frame(self): + def frame_is_valid(self): """ Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame has not been corrupted. @@ -141,10 +137,11 @@ class WebSocketsFrame(object): assert self.actual_payload_length == len(self.payload) if self.payload is not None and self.masking_key is not None: - apply_mask(self.payload, self.masking_key) == self.decoded_payload + assert apply_mask(self.payload, self.masking_key) == self.decoded_payload + return True except AssertionError: - raise WebSocketFrameValidationException() + return False def human_readable(self): return "\n".join([ @@ -161,15 +158,19 @@ class WebSocketsFrame(object): ("actual_payload_length - " + str(self.actual_payload_length)), ("use_validation - " + str(self.use_validation))]) + def safe_to_bytes(self): + try: + assert self.frame_is_valid() + return self.to_bytes() + except: + raise WebSocketFrameValidationException() + def to_bytes(self): """ Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees that the + serialized bytes will be correct. see safe_to_bytes() """ - # validate enforces all the assumptions made by this serializer - # in the spritit of mitmproxy, it's possible to create and serialize invalid frames - # by skipping validation. - if self.use_validation: - self.validate_frame() max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) @@ -198,6 +199,7 @@ class WebSocketsFrame(object): pass elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length @@ -284,9 +286,6 @@ def apply_mask(message, masking_key): def random_masking_key(): return os.urandom(4) -def masking_key_list(masking_key): - return [utils.bytes_to_int(byte) for byte in masking_key] - def create_client_handshake(host, port, key, version, resource): """ WebSockets connections are intiated by the client with a valid HTTP upgrade request diff --git a/test/test_websockets.py b/test/test_websockets.py index d7e1627fa..0b2647ef3 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,15 +1,29 @@ from netlib import test -from netlib.websockets import implementations as ws +from netlib.websockets import implementations as impl +from netlib.websockets import websockets as ws +import os class TestWebSockets(test.ServerTestBase): - handler = ws.WebSocketsEchoHandler + handler = impl.WebSocketsEchoHandler - def test_websockets_echo(self): - msg = "hello I'm the client" - client = ws.WebSocketsClient(("127.0.0.1", self.port)) + def echo(self, msg): + client = impl.WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message(msg) response = client.read_next_message() print "Assert response: " + response + " == msg: " + msg assert response == msg + def test_simple_echo(self): + self.echo("hello I'm the client") + + def test_frame_sizes(self): + small_string = os.urandom(100) # length can fit in the the 7 bit payload length + medium_string = os.urandom(50000) # 50kb, sligthly larger than can fit in a 7 bit int + large_string = os.urandom(150000) # 150kb, slightly larger than can fit in a 16 bit int + + self.echo(small_string) + self.echo(medium_string) + self.echo(large_string) + + From 73ce169e3d11eeabeb78143bd86edfdbc3e07fd9 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 12 Apr 2015 10:26:09 +1200 Subject: [PATCH 03/15] Initial outline of a cookie parsing and serialization module. --- .env | 5 ++ netlib/http_cookies.py | 133 ++++++++++++++++++++++++++++++++++++++ test/test_http_cookies.py | 106 ++++++++++++++++++++++++++++++ 3 files changed, 244 insertions(+) create mode 100644 .env create mode 100644 netlib/http_cookies.py create mode 100644 test/test_http_cookies.py diff --git a/.env b/.env new file mode 100644 index 000000000..7f847e29f --- /dev/null +++ b/.env @@ -0,0 +1,5 @@ +DIR=`dirname $0` +if [ -z "$VIRTUAL_ENV" ] && [ -f $DIR/../venv.mitmproxy/bin/activate ]; then + echo "Activating mitmproxy virtualenv..." + source $DIR/../venv.mitmproxy/bin/activate +fi diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py new file mode 100644 index 000000000..e11e0f904 --- /dev/null +++ b/netlib/http_cookies.py @@ -0,0 +1,133 @@ +""" +A flexible module for cookie parsing and manipulation. + +We try to be as permissive as possible. Parsing accepts formats from RFC6265 an +RFC2109. Serialization follows RFC6265 strictly. + + http://tools.ietf.org/html/rfc6265 + http://tools.ietf.org/html/rfc2109 +""" + +import re + +import odict + + +def _read_until(s, start, term): + """ + Read until one of the characters in term is reached. + """ + if start == len(s): + return "", start+1 + for i in range(start, len(s)): + if s[i] in term: + return s[start:i], i + return s[start:i+1], i+1 + + +def _read_token(s, start): + """ + Read a token - the LHS of a token/value pair in a cookie. + """ + return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): + """ + start: offset to the first quote of the string to be read + + A sort of loose super-set of the various quoted string specifications. + + RFC6265 disallows backslashes or double quotes within quoted strings. + Prior RFCs use backslashes to escape. This leaves us free to apply + backslash escaping by default and be compatible with everything. + """ + escaping = False + ret = [] + # Skip the first quote + for i in range(start+1, len(s)): + if escaping: + ret.append(s[i]) + escaping = False + elif s[i] == '"': + break + elif s[i] == "\\": + escaping = True + pass + else: + ret.append(s[i]) + return "".join(ret), i+1 + + +def _read_value(s, start): + """ + Reads a value - the RHS of a token/value pair in a cookie. + """ + if s[start] == '"': + return _read_quoted_string(s, start) + else: + return _read_until(s, start, ";,") + + +def _read_pairs(s): + """ + Read pairs of lhs=rhs values. + """ + off = 0 + vals = [] + while 1: + lhs, off = _read_token(s, off) + rhs = None + if off < len(s): + if s[off] == "=": + rhs, off = _read_value(s, off+1) + vals.append([lhs.lstrip(), rhs]) + off += 1 + if not off < len(s): + break + return vals, off + + +ESCAPE = re.compile(r"([\"\\])") +SPECIAL = re.compile(r"^\w+$") + + +def _format_pairs(lst): + vals = [] + for k, v in lst: + if v is None: + vals.append(k) + else: + match = SPECIAL.search(v) + if match: + v = ESCAPE.sub(r"\1", v) + vals.append("%s=%s"%(k, v)) + return "; ".join(vals) + + +def parse_cookies(s): + """ + Parses a Cookie header value. + Returns an ODict object. + """ + pairs, off = _read_pairs(s) + return odict.ODict(pairs) + + +def unparse_cookies(od): + """ + Formats a Cookie header value. + """ + vals = [] + for i in od.lst: + vals.append("%s=%s"%(i[0], i[1])) + return "; ".join(vals) + + + +def parse_set_cookies(s): + start = 0 + + +def unparse_set_cookies(s): + pass diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py new file mode 100644 index 000000000..b3f1f9146 --- /dev/null +++ b/test/test_http_cookies.py @@ -0,0 +1,106 @@ +from netlib import http_cookies, odict +import nose.tools + + +def test_read_token(): + tokens = [ + [("foo", 0), ("foo", 3)], + [("foo", 1), ("oo", 3)], + [(" foo", 1), ("foo", 4)], + [(" foo;", 1), ("foo", 4)], + [(" foo=", 1), ("foo", 4)], + [(" foo=bar", 1), ("foo", 4)], + ] + for q, a in tokens: + nose.tools.eq_(http_cookies._read_token(*q), a) + + +def test_read_quoted_string(): + tokens = [ + [('"foo" x', 0), ("foo", 5)], + [('"f\oo" x', 0), ("foo", 6)], + [(r'"f\\o" x', 0), (r"f\o", 6)], + [(r'"f\\" x', 0), (r"f" + '\\', 5)], + [('"fo\\\"" x', 0), ("fo\"", 6)], + ] + for q, a in tokens: + nose.tools.eq_(http_cookies._read_quoted_string(*q), a) + + +def test_read_pairs(): + vals = [ + [ + "one", + [["one", None]] + ], + [ + "one=two", + [["one", "two"]] + ], + [ + 'one="two"', + [["one", "two"]] + ], + [ + 'one="two"; three=four', + [["one", "two"], ["three", "four"]] + ], + [ + 'one="two"; three=four; five', + [["one", "two"], ["three", "four"], ["five", None]] + ], + [ + 'one="\\"two"; three=four', + [["one", '"two'], ["three", "four"]] + ], + ] + for s, lst in vals: + ret, off = http_cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + + +def test_pairs_roundtrips(): + pairs = [ + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one", + [["one", None]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="uno"; two="\due"', + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="un\\"o"', + [["one", 'un"o']] + ], + [ + "one=uno; two; three=tre", + [["one", "uno"], ["two", None], ["three", "tre"]] + ], + [ + "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " + "_rcc2=53VdltWl+Ov6ordflA==;", + [ + ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], + ["_rcc2", "53VdltWl+Ov6ordflA=="] + ] + ] + ] + for s, lst in pairs: + ret, off = http_cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + s2 = http_cookies._format_pairs(lst) + ret, off = http_cookies._read_pairs(s2) + nose.tools.eq_(ret, lst) + + +def test_parse_set_cookie(): + pass From 2630da7263242411d413b5e4b2c520d29848c918 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 12 Apr 2015 11:26:02 +1200 Subject: [PATCH 04/15] cookies: Cater for special values, fix some bugs found in real-world testing --- netlib/http_cookies.py | 48 +++++++++++++++++++++++++++------------ test/test_http_cookies.py | 8 +++++++ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index e11e0f904..826754180 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -59,29 +59,39 @@ def _read_quoted_string(s, start): return "".join(ret), i+1 -def _read_value(s, start): +def _read_value(s, start, special): """ Reads a value - the RHS of a token/value pair in a cookie. + + special: If the value is special, commas are premitted. Else comma + terminates. This helps us support old and new style values. """ - if s[start] == '"': + if start >= len(s): + return "", start + elif s[start] == '"': return _read_quoted_string(s, start) + elif special: + return _read_until(s, start, ";") else: return _read_until(s, start, ";,") -def _read_pairs(s): +def _read_pairs(s, specials=()): """ Read pairs of lhs=rhs values. + + specials: A lower-cased list of keys that may contain commas. """ off = 0 vals = [] while 1: lhs, off = _read_token(s, off) + lhs = lhs.lstrip() rhs = None if off < len(s): if s[off] == "=": - rhs, off = _read_value(s, off+1) - vals.append([lhs.lstrip(), rhs]) + rhs, off = _read_value(s, off+1, lhs.lower() in specials) + vals.append([lhs, rhs]) off += 1 if not off < len(s): break @@ -89,18 +99,30 @@ def _read_pairs(s): ESCAPE = re.compile(r"([\"\\])") -SPECIAL = re.compile(r"^\w+$") -def _format_pairs(lst): +def _has_special(s): + for i in s: + if i in '",;\\': + return True + o = ord(i) + if o < 0x21 or o > 0x7e: + return True + return False + + +def _format_pairs(lst, specials=()): + """ + specials: A lower-cased list of keys that will not be quoted. + """ vals = [] for k, v in lst: if v is None: vals.append(k) else: - match = SPECIAL.search(v) - if match: - v = ESCAPE.sub(r"\1", v) + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"'%v vals.append("%s=%s"%(k, v)) return "; ".join(vals) @@ -118,11 +140,7 @@ def unparse_cookies(od): """ Formats a Cookie header value. """ - vals = [] - for i in od.lst: - vals.append("%s=%s"%(i[0], i[1])) - return "; ".join(vals) - + return _format_pairs(od.lst) def parse_set_cookies(s): diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py index b3f1f9146..31e5f0b0d 100644 --- a/test/test_http_cookies.py +++ b/test/test_http_cookies.py @@ -37,6 +37,10 @@ def test_read_pairs(): "one=two", [["one", "two"]] ], + [ + "one=", + [["one", ""]] + ], [ 'one="two"', [["one", "two"]] @@ -81,6 +85,10 @@ def test_pairs_roundtrips(): 'one="un\\"o"', [["one", 'un"o']] ], + [ + 'one="uno,due"', + [["one", 'uno,due']] + ], [ "one=uno; two; three=tre", [["one", "uno"], ["two", None], ["three", "tre"]] From f131f9b855e77554072415c925ed112ec74ee48a Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 15:40:18 -0700 Subject: [PATCH 05/15] handshake tests, serialization test --- netlib/websockets/implementations.py | 19 +++++---- netlib/websockets/websockets.py | 51 +++++++++++++++------- test/test_websockets.py | 63 ++++++++++++++++++++++++---- 3 files changed, 105 insertions(+), 28 deletions(-) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index ff42ff658..73a846905 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -32,7 +32,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.server_process_handshake(client_hs) + key = ws.process_handshake_from_client(client_hs) response = ws.create_server_handshake(key) self.wfile.write(response) self.wfile.flush() @@ -46,9 +46,9 @@ class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.key = ws.generate_client_nounce() - self.resource = "/" + self.version = "13" + self.client_nounce = ws.create_client_nounce() + self.resource = "/" def connect(self): super(WebSocketsClient, self).connect() @@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient): handshake = ws.create_client_handshake( self.address.host, self.address.port, - self.key, + self.client_nounce, self.version, self.resource ) @@ -64,9 +64,14 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(handshake) self.wfile.flush() - response = ws.read_handshake(self.rfile.read, 1) + server_handshake = ws.read_handshake(self.rfile.read, 1) - if not response: + if not server_handshake: + self.close() + + server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) + + if not server_nounce == ws.create_server_nounce(self.client_nounce): self.close() def read_next_message(self): diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 527d55d62..cf9a68aa9 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -84,7 +84,7 @@ class WebSocketsFrame(object): Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_byte_stream() directly """ - self.from_byte_stream(io.BytesIO(bytestring).read) + return cls.from_byte_stream(io.BytesIO(bytestring).read) @classmethod @@ -115,7 +115,7 @@ class WebSocketsFrame(object): actual_payload_length = actual_length ) - def frame_is_valid(self): + def is_valid(self): """ Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame has not been corrupted. @@ -155,12 +155,11 @@ class WebSocketsFrame(object): ("masking_key - " + str(self.masking_key)), ("payload - " + str(self.payload)), ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length)), - ("use_validation - " + str(self.use_validation))]) + ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): try: - assert self.frame_is_valid() + assert self.is_valid() return self.to_bytes() except: raise WebSocketFrameValidationException() @@ -197,7 +196,7 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - + elif self.actual_payload_length < max_16_bit_int: # '!H' pack as 16 bit unsigned short @@ -267,6 +266,20 @@ class WebSocketsFrame(object): actual_payload_length = actual_payload_length ) + def __eq__(self, other): + return ( + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length) + def apply_mask(message, masking_key): """ Data sent from the server must be masked to prevent malicious clients @@ -300,16 +313,14 @@ def create_client_handshake(host, port, key, version, resource): request = "GET %s HTTP/1.1" % resource return build_handshake(headers, request) - -def create_server_handshake(key, magic = websockets_magic): +def create_server_handshake(key): """ The server response is a valid HTTP 101 response. """ - digest = b64encode(sha1(key + magic).hexdigest().decode('hex')) headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', digest) + ('Sec-WebSocket-Accept', create_server_nounce(key)) ] request = "HTTP/1.1 101 Switching Protocols" return build_handshake(headers, request) @@ -322,7 +333,6 @@ def build_handshake(headers, request): handshake.append(b'\r\n') return b'\r\n'.join(handshake) - def read_handshake(read_bytes, num_bytes_per_read): """ From provided function that reads bytes, read in a @@ -355,13 +365,26 @@ def get_payload_length_pair(payload_bytestring): length_code = 127 return (length_code, actual_length) -def server_process_handshake(handshake): - headers = Message(StringIO(handshake.split('\r\n', 1)[1])) +def process_handshake_from_client(handshake): + headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": return key = headers['Sec-WebSocket-Key'] return key -def generate_client_nounce(): +def process_handshake_from_server(handshake, client_nounce): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Accept'] + return key + +def headers_from_http_message(http_message): + return Message(StringIO(http_message.split('\r\n', 1)[1])) + +def create_server_nounce(client_nounce): + return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + +def create_client_nounce(): return b64encode(os.urandom(16)).decode('utf-8') diff --git a/test/test_websockets.py b/test/test_websockets.py index 0b2647ef3..a5ebf3d1b 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,29 +1,78 @@ +from netlib import tcp from netlib import test from netlib.websockets import implementations as impl from netlib.websockets import websockets as ws import os +from nose.tools import raises class TestWebSockets(test.ServerTestBase): handler = impl.WebSocketsEchoHandler + def random_bytes(self, n = 100): + return os.urandom(n) + def echo(self, msg): client = impl.WebSocketsClient(("127.0.0.1", self.port)) client.connect() client.send_message(msg) response = client.read_next_message() - print "Assert response: " + response + " == msg: " + msg assert response == msg def test_simple_echo(self): self.echo("hello I'm the client") def test_frame_sizes(self): - small_string = os.urandom(100) # length can fit in the the 7 bit payload length - medium_string = os.urandom(50000) # 50kb, sligthly larger than can fit in a 7 bit int - large_string = os.urandom(150000) # 150kb, slightly larger than can fit in a 16 bit int + small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length + medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int + large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int + + self.echo(small_msg) + self.echo(medium_msg) + self.echo(large_msg) + + def test_default_builder(self): + """ + default builder should always generate valid frames + """ + msg = self.random_bytes() + client_frame = ws.WebSocketsFrame.default(msg, from_client = True) + assert client_frame.is_valid() + + server_frame = ws.WebSocketsFrame.default(msg, from_client = False) + assert server_frame.is_valid() + + def test_serialization_bijection(self): + for is_client in [True, False]: + for num_bytes in [100, 50000, 150000]: + frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client) + assert frame == ws.WebSocketsFrame.from_bytes(frame.to_bytes()) + + bytes = b'\x81\x11cba' + assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes + + +class BadHandshakeHandler(impl.WebSocketsEchoHandler): + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + key = ws.process_handshake_from_client(client_hs) + response = ws.create_server_handshake("malformed_key") + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + +class TestBadHandshake(test.ServerTestBase): + """ + Ensure that the client disconnects if the server handshake is malformed + """ + handler = BadHandshakeHandler + + @raises(tcp.NetLibDisconnect) + def test(self): + client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message("hello") + + - self.echo(small_string) - self.echo(medium_string) - self.echo(large_string) From 0ed2a290639833d772b89cf333577820e84f8204 Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 17:28:52 -0700 Subject: [PATCH 06/15] whitespace --- test/test_websockets.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/test_websockets.py b/test/test_websockets.py index a5ebf3d1b..0c23e355f 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -70,9 +70,4 @@ class TestBadHandshake(test.ServerTestBase): def test(self): client = impl.WebSocketsClient(("127.0.0.1", self.port)) client.connect() - client.send_message("hello") - - - - - + client.send_message("hello") \ No newline at end of file From 2d72a1b6b56f1643cd1d8be59eee55aa7ca2f17f Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Mon, 13 Apr 2015 13:36:09 -0700 Subject: [PATCH 07/15] 100% test coverage, though still need plenty more --- netlib/http.py | 14 -------------- netlib/websockets/implementations.py | 10 ++-------- netlib/websockets/websockets.py | 9 ++++----- test/test_websockets.py | 14 ++++++++++++-- 4 files changed, 18 insertions(+), 29 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 2c72621dd..264388636 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -29,20 +29,6 @@ def _is_valid_host(host): return None return True -def is_successful_upgrade(request, response): - """ - determines if a client and server successfully agreed to an HTTP protocol upgrade - - https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism - """ - http_switching_protocols_code = 101 - - if request and response: - responseUpgrade = request.headers.get("Upgrade") - requestUpgrade = response.headers.get("Upgrade") - if response.code == http_switching_protocols_code and responseUpgrade == requestUpgrade: - return requestUpgrade[0] if len(requestUpgrade) > 0 else None - return None def parse_url(url): """ diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 73a846905..1ded3b857 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -65,9 +65,6 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.flush() server_handshake = ws.read_handshake(self.rfile.read, 1) - - if not server_handshake: - self.close() server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) @@ -75,11 +72,8 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - try: - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload - except IndexError: - self.close() - + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + def send_message(self, message): frame = ws.WebSocketsFrame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index cf9a68aa9..ea3db21d5 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -158,11 +158,10 @@ class WebSocketsFrame(object): ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): - try: - assert self.is_valid() - return self.to_bytes() - except: - raise WebSocketFrameValidationException() + if self.is_valid(): + return self.to_bytes() + else: + raise WebSocketFrameValidationException() def to_bytes(self): """ diff --git a/test/test_websockets.py b/test/test_websockets.py index 0c23e355f..951aa41ff 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -22,8 +22,8 @@ class TestWebSockets(test.ServerTestBase): self.echo("hello I'm the client") def test_frame_sizes(self): - small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length - medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int + small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length + medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int self.echo(small_msg) @@ -42,6 +42,10 @@ class TestWebSockets(test.ServerTestBase): assert server_frame.is_valid() def test_serialization_bijection(self): + """ + Ensure that various frame types can be serialized/deserialized back and forth + between to_bytes() and from_bytes() + """ for is_client in [True, False]: for num_bytes in [100, 50000, 150000]: frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client) @@ -50,6 +54,12 @@ class TestWebSockets(test.ServerTestBase): bytes = b'\x81\x11cba' assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes + @raises(ws.WebSocketFrameValidationException) + def test_safe_to_bytes(self): + frame = ws.WebSocketsFrame.default(self.random_bytes(8)) + frame.actual_payload_length = 1 #corrupt the frame + frame.safe_to_bytes() + class BadHandshakeHandler(impl.WebSocketsEchoHandler): def handshake(self): From de9e7411253c4f67ea4d0b96f6f9e952024c5fa3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 10:02:10 +1200 Subject: [PATCH 08/15] Firm up cookie parsing and formatting API Make a tough call: we won't support old-style comma-separated set-cookie headers. Real world testing has shown that the latest rfc (6265) is often violated in ways that make the parsing problem indeterminate. Since this is much more common than the old style deprecated set-cookie variant, we focus on the most useful case. --- netlib/http_cookies.py | 114 +++++++++++++++++++++++++++---------- test/test_http_cookies.py | 115 +++++++++++++++++++++++++++++++++++++- 2 files changed, 196 insertions(+), 33 deletions(-) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 826754180..a1f240f5c 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -1,13 +1,27 @@ """ A flexible module for cookie parsing and manipulation. -We try to be as permissive as possible. Parsing accepts formats from RFC6265 an -RFC2109. Serialization follows RFC6265 strictly. +This module differs from usual standards-compliant cookie modules in a number of +ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple cookies +to be set in a single header. Technically this should be feasible, but it turns +out that violations of RFC6265 that makes the parsing problem indeterminate are +much more common than genuine occurences of the multi-cookie variants. +Serialization follows RFC6265. http://tools.ietf.org/html/rfc6265 http://tools.ietf.org/html/rfc2109 + http://tools.ietf.org/html/rfc2965 """ +# TODO +# - Disallow LHS-only Cookie values + import re import odict @@ -59,7 +73,7 @@ def _read_quoted_string(s, start): return "".join(ret), i+1 -def _read_value(s, start, special): +def _read_value(s, start, delims): """ Reads a value - the RHS of a token/value pair in a cookie. @@ -70,37 +84,41 @@ def _read_value(s, start, special): return "", start elif s[start] == '"': return _read_quoted_string(s, start) - elif special: - return _read_until(s, start, ";") else: - return _read_until(s, start, ";,") + return _read_until(s, start, delims) -def _read_pairs(s, specials=()): +def _read_pairs(s, off=0, term=None, specials=()): """ Read pairs of lhs=rhs values. - specials: A lower-cased list of keys that may contain commas. + off: start offset + term: if True, treat a comma as a terminator for the pairs lists + specials: a lower-cased list of keys that may contain commas if term is + True """ - off = 0 vals = [] while 1: lhs, off = _read_token(s, off) lhs = lhs.lstrip() - rhs = None - if off < len(s): - if s[off] == "=": - rhs, off = _read_value(s, off+1, lhs.lower() in specials) - vals.append([lhs, rhs]) + if lhs: + rhs = None + if off < len(s): + if s[off] == "=": + if term and lhs.lower() not in specials: + delims = ";," + else: + delims = ";" + rhs, off = _read_value(s, off+1, delims) + vals.append([lhs, rhs]) off += 1 if not off < len(s): break + if term and s[off-1] == ",": + break return vals, off -ESCAPE = re.compile(r"([\"\\])") - - def _has_special(s): for i in s: if i in '",;\\': @@ -111,6 +129,9 @@ def _has_special(s): return False +ESCAPE = re.compile(r"([\"\\])") + + def _format_pairs(lst, specials=()): """ specials: A lower-cased list of keys that will not be quoted. @@ -127,25 +148,58 @@ def _format_pairs(lst, specials=()): return "; ".join(vals) -def parse_cookies(s): +def _format_set_cookie_pairs(lst): + return _format_pairs( + lst, + specials = ("expires", "path") + ) + + +def _parse_set_cookie_pairs(s): """ - Parses a Cookie header value. - Returns an ODict object. + For Set-Cookie, we support multiple cookies as described in RFC2109. + This function therefore returns a list of lists. """ - pairs, off = _read_pairs(s) + pairs, off = _read_pairs( + s, + specials = ("expires", "path") + ) + return pairs + + +def parse_set_cookie_header(str): + """ + Parse a Set-Cookie header value + + Returns a (name, value, attrs) tuple, or None, where attrs is an + ODictCaseless set of attributes. No attempt is made to parse attribute + values - they are treated purely as strings. + """ + pairs = _parse_set_cookie_pairs(str) + if pairs: + return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): + """ + Formats a Set-Cookie header value. + """ + pairs = [[name, value]] + pairs.extend(attrs.lst) + return _format_set_cookie_pairs(pairs) + + +def parse_cookie_header(str): + """ + Parse a Cookie header value. + Returns a (possibly empty) ODict object. + """ + pairs, off = _read_pairs(str) return odict.ODict(pairs) -def unparse_cookies(od): +def format_cookie_header(od): """ Formats a Cookie header value. """ return _format_pairs(od.lst) - - -def parse_set_cookies(s): - start = 0 - - -def unparse_set_cookies(s): - pass diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py index 31e5f0b0d..c0e5a5b76 100644 --- a/test/test_http_cookies.py +++ b/test/test_http_cookies.py @@ -1,6 +1,8 @@ -from netlib import http_cookies, odict +import pprint import nose.tools +from netlib import http_cookies, odict + def test_read_token(): tokens = [ @@ -65,6 +67,10 @@ def test_read_pairs(): def test_pairs_roundtrips(): pairs = [ + [ + "", + [] + ], [ "one=uno", [["one", "uno"]] @@ -110,5 +116,108 @@ def test_pairs_roundtrips(): nose.tools.eq_(ret, lst) -def test_parse_set_cookie(): - pass +def test_cookie_roundtrips(): + pairs = [ + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + ] + for s, lst in pairs: + ret = http_cookies.parse_cookie_header(s) + nose.tools.eq_(ret.lst, lst) + s2 = http_cookies.format_cookie_header(ret) + ret = http_cookies.parse_cookie_header(s2) + nose.tools.eq_(ret.lst, lst) + + +# TODO +# I've seen the following pathological cookie in the wild: +# +# cid=09,0,0,0,0; expires=Wed, 10-Jun-2015 21:54:53 GMT; path=/ +# +# It's not compliant under any RFC - the latest RFC prohibits commas in cookie +# values completely, earlier RFCs require them to be within a quoted string. +# +# If we ditch support for earlier RFCs, we can handle this correctly. This +# leaves us with the question: what's more common, multiple-value Set-Cookie +# headers, or Set-Cookie headers that violate the standards? + +def test_parse_set_cookie_pairs(): + pairs = [ + [ + "one=uno", + [ + ["one", "uno"] + ] + ], + [ + "one=uno; foo", + [ + ["one", "uno"], + ["foo", None] + ] + ], + [ + "mun=1.390.f60; " + "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " + "domain=b.aol.com", + [ + ["mun", "1.390.f60"], + ["expires", "sun, 11-oct-2015 12:38:31 gmt"], + ["path", "/"], + ["domain", "b.aol.com"] + ] + ], + [ + r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' + 'domain=.rubiconproject.com; ' + 'expires=mon, 11-may-2015 21:54:57 gmt; ' + 'path=/', + [ + ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], + ['domain', '.rubiconproject.com'], + ['expires', 'mon, 11-may-2015 21:54:57 gmt'], + ['path', '/'] + ] + ], + ] + for s, lst in pairs: + ret = http_cookies._parse_set_cookie_pairs(s) + nose.tools.eq_(ret, lst) + s2 = http_cookies._format_set_cookie_pairs(ret) + ret2 = http_cookies._parse_set_cookie_pairs(s2) + nose.tools.eq_(ret2, lst) + + +def test_parse_set_cookie_header(): + vals = [ + [ + "", None + ], + [ + "one=uno", + ("one", "uno", []) + ], + [ + "one=uno; foo=bar", + ("one", "uno", [["foo", "bar"]]) + ] + ] + for s, expected in vals: + ret = http_cookies.parse_set_cookie_header(s) + if expected: + assert ret[0] == expected[0] + assert ret[1] == expected[1] + nose.tools.eq_(ret[2].lst, expected[2]) + s2 = http_cookies.format_set_cookie_header(*ret) + ret2 = http_cookies.parse_set_cookie_header(s2) + assert ret2[0] == expected[0] + assert ret2[1] == expected[1] + nose.tools.eq_(ret2[2].lst, expected[2]) + else: + assert ret is None From 6db5e0a4a133e6e6150f9cab87cd56b40d6db0b2 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 10:13:03 +1200 Subject: [PATCH 09/15] Remove old-style set-cookie cruft, unit tests to 100% --- netlib/http_cookies.py | 14 +++----------- test/test_http_cookies.py | 6 ++++++ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index a1f240f5c..297efb80d 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -88,14 +88,12 @@ def _read_value(s, start, delims): return _read_until(s, start, delims) -def _read_pairs(s, off=0, term=None, specials=()): +def _read_pairs(s, off=0, specials=()): """ Read pairs of lhs=rhs values. off: start offset - term: if True, treat a comma as a terminator for the pairs lists - specials: a lower-cased list of keys that may contain commas if term is - True + specials: a lower-cased list of keys that may contain commas """ vals = [] while 1: @@ -105,17 +103,11 @@ def _read_pairs(s, off=0, term=None, specials=()): rhs = None if off < len(s): if s[off] == "=": - if term and lhs.lower() not in specials: - delims = ";," - else: - delims = ";" - rhs, off = _read_value(s, off+1, delims) + rhs, off = _read_value(s, off+1, ";") vals.append([lhs, rhs]) off += 1 if not off < len(s): break - if term and s[off-1] == ",": - break return vals, off diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py index c0e5a5b76..ad509254f 100644 --- a/test/test_http_cookies.py +++ b/test/test_http_cookies.py @@ -155,6 +155,12 @@ def test_parse_set_cookie_pairs(): ["one", "uno"] ] ], + [ + "one=un\x20", + [ + ["one", "un\x20"] + ] + ], [ "one=uno; foo", [ From d739882bf2dc65925c001c5bf848f5664640d299 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 13:50:57 +1200 Subject: [PATCH 10/15] Add an .extend method for ODicts --- netlib/odict.py | 6 ++++++ test/test_odict.py | 7 ++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/netlib/odict.py b/netlib/odict.py index 7a2f611b2..7a54f282f 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -108,6 +108,12 @@ class ODict(object): lst = copy.deepcopy(self.lst) return self.__class__(lst) + def extend(self, other): + """ + Add the contents of other, preserving any duplicates. + """ + self.lst.extend(other.lst) + def __repr__(self): elements = [] for itm in self.lst: diff --git a/test/test_odict.py b/test/test_odict.py index d90bc6e56..c2415b6d3 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -109,6 +109,12 @@ class TestODict: assert self.od.get_first("one") == "two" assert self.od.get_first("two") == None + def test_extend(self): + a = odict.ODict([["a", "b"], ["c", "d"]]) + b = odict.ODict([["a", "b"], ["e", "f"]]) + a.extend(b) + assert len(a) == 4 + assert a["a"] == ["b", "b"] class TestODictCaseless: def setUp(self): @@ -144,4 +150,3 @@ class TestODictCaseless: assert self.od.keys() == ["foo"] self.od.add("bar", 2) assert len(self.od.keys()) == 2 - From aeebf31927eb3ff74824525005c7b146024de6d5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 16:20:02 +1200 Subject: [PATCH 11/15] odict: don't convert values to strings when added --- netlib/odict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/odict.py b/netlib/odict.py index 7a54f282f..a0ea9e53b 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -84,7 +84,7 @@ class ODict(object): return False def add(self, key, value): - self.lst.append([key, str(value)]) + self.lst.append([key, value]) def get(self, k, d=None): if k in self: @@ -117,7 +117,7 @@ class ODict(object): def __repr__(self): elements = [] for itm in self.lst: - elements.append(itm[0] + ": " + itm[1]) + elements.append(itm[0] + ": " + str(itm[1])) elements.append("") return "\r\n".join(elements) From 0c85c72dc43d0d017e2bf5af9c2def46968d0499 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 15 Apr 2015 10:28:17 +1200 Subject: [PATCH 12/15] ODict improvements - Setting values now tries to preserve the existing order, rather than just appending to the end. - __repr__ now returns a repr of the tuple list. The old repr becomes a .format() method. This is clearer, makes troubleshooting easier, and doesn't assume all data in ODicts are header-like --- netlib/odict.py | 25 +++++++++++++++++++------ netlib/wsgi.py | 29 ++++++++++++++++++----------- test/test_http.py | 11 +++++++++-- test/test_http_cookies.py | 15 +++------------ test/test_odict.py | 25 +++++++++++++++++++++++-- test/test_wsgi.py | 1 - 6 files changed, 72 insertions(+), 34 deletions(-) diff --git a/netlib/odict.py b/netlib/odict.py index a0ea9e53b..dd738c55d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -13,7 +13,8 @@ def safe_subn(pattern, repl, target, *args, **kwargs): class ODict(object): """ - A dictionary-like object for managing ordered (key, value) data. + A dictionary-like object for managing ordered (key, value) data. Think + about it as a convenient interface to a list of (key, value) tuples. """ def __init__(self, lst=None): self.lst = lst or [] @@ -64,11 +65,20 @@ class ODict(object): key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") - - new = self._filter_lst(k, self.lst) - for i in valuelist: - new.append([k, i]) + raise ValueError( + "Expected list of values instead of string. " + "Example: odict['Host'] = ['www.example.com']" + ) + kc = self._kconv(k) + new = [] + for i in self.lst: + if self._kconv(i[0]) == kc: + if valuelist: + new.append([k, valuelist.pop(0)]) + else: + new.append(i) + while valuelist: + new.append([k, valuelist.pop(0)]) self.lst = new def __delitem__(self, k): @@ -115,6 +125,9 @@ class ODict(object): self.lst.extend(other.lst) def __repr__(self): + return repr(self.lst) + + def format(self): elements = [] for itm in self.lst: elements.append(itm[0] + ": " + str(itm[1])) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index bac27d5af..1b9796081 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,5 +1,8 @@ from __future__ import (absolute_import, print_function, division) -import cStringIO, urllib, time, traceback +import cStringIO +import urllib +import time +import traceback from . import odict, tcp @@ -23,15 +26,18 @@ class Request(object): def date_time_string(): """Return the current date and time formatted for a message header.""" WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - MONTHS = [None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + MONTHS = [ + None, + 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', + 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' + ] now = time.time() year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - WEEKS[wd], - day, MONTHS[month], year, - hh, mm, ss) + WEEKS[wd], + day, MONTHS[month], year, + hh, mm, ss + ) return s @@ -100,6 +106,7 @@ class WSGIAdaptor(object): status = None, headers = None ) + def write(data): if not state["headers_sent"]: soc.write("HTTP/1.1 %s\r\n"%state["status"]) @@ -108,7 +115,7 @@ class WSGIAdaptor(object): h["Server"] = [self.sversion] if 'date' not in h: h["Date"] = [date_time_string()] - soc.write(str(h)) + soc.write(h.format()) soc.write("\r\n") state["headers_sent"] = True if data: @@ -130,7 +137,9 @@ class WSGIAdaptor(object): errs = cStringIO.StringIO() try: - dataiter = self.app(self.make_environ(request, errs, **env), start_response) + dataiter = self.app( + self.make_environ(request, errs, **env), start_response + ) for i in dataiter: write(i) if not state["headers_sent"]: @@ -143,5 +152,3 @@ class WSGIAdaptor(object): except Exception: # pragma: no cover pass return errs.getvalue() - - diff --git a/test/test_http.py b/test/test_http.py index fed609464..b1c62458b 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -53,6 +53,7 @@ def test_connection_close(): h["connection"] = ["close"] assert http.connection_close((1, 1), h) + def test_get_header_tokens(): h = odict.ODictCaseless() assert http.get_header_tokens(h, "foo") == [] @@ -69,11 +70,13 @@ def test_read_http_body_request(): r = cStringIO.StringIO("testing") assert http.read_http_body(r, h, None, "GET", None, True) == "" + def test_read_http_body_response(): h = odict.ODictCaseless() s = cStringIO.StringIO("testing") assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" + def test_read_http_body(): # test default case h = odict.ODictCaseless() @@ -115,6 +118,7 @@ def test_read_http_body(): s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + def test_expected_http_body_size(): # gibber in the content-length field h = odict.ODictCaseless() @@ -135,6 +139,7 @@ def test_expected_http_body_size(): h = odict.ODictCaseless() assert http.expected_http_body_size(h, True, "GET", None) == 0 + def test_parse_http_protocol(): assert http.parse_http_protocol("HTTP/1.1") == (1, 1) assert http.parse_http_protocol("HTTP/0.0") == (0, 0) @@ -189,6 +194,7 @@ def test_parse_init_http(): assert not http.parse_init_http("GET /test foo/1.1") assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") + class TestReadHeaders: def _read(self, data, verbatim=False): if not verbatim: @@ -251,11 +257,12 @@ class TestReadResponseNoContentLength(test.ServerTestBase): httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) assert content == "bar\r\n\r\n" + def test_read_response(): def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) r = cStringIO.StringIO(data) - return http.read_response(r, method, limit, include_body=include_body) + return http.read_response(r, method, limit, include_body = include_body) tutils.raises("server disconnect", tst, "", "GET", None) tutils.raises("invalid server response", tst, "foo", "GET", None) @@ -351,6 +358,7 @@ def test_parse_url(): # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt assert not http.parse_url('http://lo[calhost') + def test_parse_http_basic_auth(): vals = ("basic", "foo", "bar") assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals @@ -358,4 +366,3 @@ def test_parse_http_basic_auth(): assert not http.parse_http_basic_auth("foo bar") v = "basic " + binascii.b2a_base64("foo") assert not http.parse_http_basic_auth(v) - diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py index ad509254f..7438af7ca 100644 --- a/test/test_http_cookies.py +++ b/test/test_http_cookies.py @@ -135,18 +135,6 @@ def test_cookie_roundtrips(): nose.tools.eq_(ret.lst, lst) -# TODO -# I've seen the following pathological cookie in the wild: -# -# cid=09,0,0,0,0; expires=Wed, 10-Jun-2015 21:54:53 GMT; path=/ -# -# It's not compliant under any RFC - the latest RFC prohibits commas in cookie -# values completely, earlier RFCs require them to be within a quoted string. -# -# If we ditch support for earlier RFCs, we can handle this correctly. This -# leaves us with the question: what's more common, multiple-value Set-Cookie -# headers, or Set-Cookie headers that violate the standards? - def test_parse_set_cookie_pairs(): pairs = [ [ @@ -205,6 +193,9 @@ def test_parse_set_cookie_header(): [ "", None ], + [ + ";", None + ], [ "one=uno", ("one", "uno", []) diff --git a/test/test_odict.py b/test/test_odict.py index c2415b6d3..c01c4dbe4 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -6,6 +6,11 @@ class TestODict: def setUp(self): self.od = odict.ODict() + def test_repr(self): + h = odict.ODict() + h["one"] = ["two"] + assert repr(h) + def test_str_err(self): h = odict.ODict() tutils.raises(ValueError, h.__setitem__, "key", "foo") @@ -20,7 +25,7 @@ class TestODict: "two: tre\r\n", "\r\n" ] - out = repr(self.od) + out = self.od.format() for i in expected: assert out.find(i) >= 0 @@ -39,7 +44,7 @@ class TestODict: self.od["one"] = ["uno"] expected1 = "one: uno\r\n" expected2 = "\r\n" - out = repr(self.od) + out = self.od.format() assert out.find(expected1) >= 0 assert out.find(expected2) >= 0 @@ -150,3 +155,19 @@ class TestODictCaseless: assert self.od.keys() == ["foo"] self.od.add("bar", 2) assert len(self.od.keys()) == 2 + + def test_add_order(self): + od = odict.ODict( + [ + ["one", "uno"], + ["two", "due"], + ["three", "tre"], + ] + ) + od["two"] = ["foo", "bar"] + assert od.lst == [ + ["one", "uno"], + ["two", "foo"], + ["three", "tre"], + ["two", "bar"], + ] diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 6e1fb146a..1c8c52635 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -100,4 +100,3 @@ class TestWSGI: start_response(status, response_headers, ei) yield "bbb" assert "Internal Server Error" in self._serve(app) - From c53d89fd7fad6c46458ab3d0140528e344de605f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 16 Apr 2015 08:30:54 +1200 Subject: [PATCH 13/15] Improve flexibility of http_cookies._format_pairs --- netlib/http_cookies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 297efb80d..dab95ed05 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -124,7 +124,7 @@ def _has_special(s): ESCAPE = re.compile(r"([\"\\])") -def _format_pairs(lst, specials=()): +def _format_pairs(lst, specials=(), sep="; "): """ specials: A lower-cased list of keys that will not be quoted. """ @@ -137,7 +137,7 @@ def _format_pairs(lst, specials=()): v = ESCAPE.sub(r"\\\1", v) v = '"%s"'%v vals.append("%s=%s"%(k, v)) - return "; ".join(vals) + return sep.join(vals) def _format_set_cookie_pairs(lst): From 488c25d812a321f5a03253b62ab33b61ecc13de1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 17 Apr 2015 13:57:39 +1200 Subject: [PATCH 14/15] websockets: whitespace, PEP8 --- netlib/websockets/websockets.py | 169 ++++++++++++++++++-------------- 1 file changed, 96 insertions(+), 73 deletions(-) diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index ea3db21d5..8782ea496 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -1,31 +1,34 @@ from __future__ import absolute_import -from base64 import b64encode -from hashlib import sha1 -from mimetools import Message -from netlib import tcp -from netlib import utils -from StringIO import StringIO +import base64 +import hashlib +import mimetools +import StringIO import os -import SocketServer import struct import io -# Colleciton of utility functions that implement small portions of the RFC6455 WebSockets Protocol -# Useful for building WebSocket clients and servers. +from .. import utils + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. # -# Emphassis is on readabilty, simplicity and modularity, not performance or completeness -# -# This is a work in progress and does not yet contain all the utilites need to create fully complient client/servers +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness # +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # # Spec: https://tools.ietf.org/html/rfc6455 -# The magic sha that websocket servers must know to prove they understand RFC6455 +# The magic sha that websocket servers must know to prove they understand +# RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + class WebSocketFrameValidationException(Exception): pass + class WebSocketsFrame(object): """ Represents one websockets frame. @@ -33,7 +36,7 @@ class WebSocketsFrame(object): from_bytes() is also avaliable. WebSockets Frame as defined in RFC6455 - + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-------+-+-------------+-------------------------------+ |F|R|R|R| opcode|M| Payload len | Extended payload length | @@ -62,7 +65,7 @@ class WebSocketsFrame(object): rsv1 = 0, # decimal integer 1 or 0 rsv2 = 0, # decimal integer 1 or 0 rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring + payload = None, # bytestring masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer ): @@ -81,18 +84,17 @@ class WebSocketsFrame(object): @classmethod def from_bytes(cls, bytestring): """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_byte_stream() directly - """ + Construct a websocket frame from an in-memory bytestring to construct + a frame from a stream of bytes, use from_byte_stream() directly + """ return cls.from_byte_stream(io.BytesIO(bytestring).read) - @classmethod def default(cls, message, from_client = False): """ - Construct a basic websocket frame from some default values. + Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. - """ + """ length_code, actual_length = get_payload_length_pair(message) if from_client: @@ -103,7 +105,7 @@ class WebSocketsFrame(object): mask_bit = 0 masking_key = None payload = message - + return cls( fin = 1, # final frame opcode = 1, # text @@ -117,10 +119,10 @@ class WebSocketsFrame(object): def is_valid(self): """ - Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame - has not been corrupted. - """ - try: + Validate websocket frame invariants, call at anytime to ensure the + WebSocketsFrame has not been corrupted. + """ + try: assert 0 <= self.fin <= 1 assert 0 <= self.rsv1 <= 1 assert 0 <= self.rsv2 <= 1 @@ -128,18 +130,18 @@ class WebSocketsFrame(object): assert 1 <= self.opcode <= 4 assert 0 <= self.mask_bit <= 1 assert 1 <= self.payload_length_code <= 127 - + if self.mask_bit == 1: assert 1 <= len(self.masking_key) <= 4 else: - assert self.masking_key == None - + assert self.masking_key is None + assert self.actual_payload_length == len(self.payload) if self.payload is not None and self.masking_key is not None: assert apply_mask(self.payload, self.masking_key) == self.decoded_payload - return True + return True except AssertionError: return False @@ -165,30 +167,32 @@ class WebSocketsFrame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring - If you haven't checked is_valid_frame() then there's no guarentees that the - serialized bytes will be correct. see safe_to_bytes() - """ + Serialize the frame back into the wire format, returns a bytestring If + you haven't checked is_valid_frame() then there's no guarentees that + the serialized bytes will be correct. see safe_to_bytes() + """ max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) - # break down of the bit-math used to construct the first byte from the frame's integer values - # first shift the significant bit into the correct position + # break down of the bit-math used to construct the first byte from the + # frame's integer values first shift the significant bit into the + # correct position # 00000001 << 7 = 10000000 # ... # then combine: - # + # # 10000000 fin # 01000000 res1 # 00100000 res2 # 00010000 res3 # 00000001 opcode - # -------- OR + # -------- OR # 11110001 = first_byte - first_byte = (self.fin << 7) | (self.rsv1 << 6) | (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode - + first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ + (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + second_byte = (self.mask_bit << 7) | self.payload_length_code bytes = chr(first_byte) + chr(second_byte) @@ -199,11 +203,13 @@ class WebSocketsFrame(object): elif self.actual_payload_length < max_16_bit_int: # '!H' pack as 16 bit unsigned short - bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length - + # add 2 byte extended payload length + bytes += struct.pack('!H', self.actual_payload_length) + elif self.actual_payload_length < max_64_bit_int: # '!Q' = pack as 64 bit unsigned long long - bytes += struct.pack('!Q', self.actual_payload_length) # add 8 bytes extended payload length + # add 8 bytes extended payload length + bytes += struct.pack('!Q', self.actual_payload_length) if self.masking_key is not None: bytes += self.masking_key @@ -212,43 +218,46 @@ class WebSocketsFrame(object): return bytes - @classmethod def from_byte_stream(cls, read_bytes): """ read a websockets frame sent by a server or client - + read_bytes is a function that can be backed - by sockets or by any byte reader. So this + by sockets or by any byte reader. So this function may be used to read frames from disk/wire/memory - """ - first_byte = utils.bytes_to_int(read_bytes(1)) + """ + first_byte = utils.bytes_to_int(read_bytes(1)) second_byte = utils.bytes_to_int(read_bytes(1)) - - fin = first_byte >> 7 # grab the left most bit - opcode = first_byte & 15 # grab right most 4 bits by and-ing with 00001111 - mask_bit = second_byte >> 7 # grab left most bit - payload_length = second_byte & 127 # grab the next 7 bits + + # grab the left most bit + fin = first_byte >> 7 + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + payload_length = second_byte & 127 # payload_lengthy > 125 indicates you need to read more bytes # to get the actual payload length if payload_length <= 125: - actual_payload_length = payload_length + actual_payload_length = payload_length elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(read_bytes(2)) + actual_payload_length = utils.bytes_to_int(read_bytes(2)) - elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(read_bytes(8)) + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) # masking key only present if mask bit set if mask_bit == 1: masking_key = read_bytes(4) else: masking_key = None - + payload = read_bytes(actual_payload_length) - + if mask_bit == 1: decoded_payload = apply_mask(payload, masking_key) else: @@ -295,12 +304,15 @@ def apply_mask(message, masking_key): result += chr(ord(char) ^ masks[len(result) % 4]) return result + def random_masking_key(): return os.urandom(4) + def create_client_handshake(host, port, key, version, resource): """ - WebSockets connections are intiated by the client with a valid HTTP upgrade request + WebSockets connections are intiated by the client with a valid HTTP + upgrade request """ headers = [ ('Host', '%s:%s' % (host, port)), @@ -312,10 +324,11 @@ def create_client_handshake(host, port, key, version, resource): request = "GET %s HTTP/1.1" % resource return build_handshake(headers, request) + def create_server_handshake(key): """ - The server response is a valid HTTP 101 response. - """ + The server response is a valid HTTP 101 response. + """ headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), @@ -332,12 +345,13 @@ def build_handshake(headers, request): handshake.append(b'\r\n') return b'\r\n'.join(handshake) + def read_handshake(read_bytes, num_bytes_per_read): """ - From provided function that reads bytes, read in a + From provided function that reads bytes, read in a complete HTTP request, which terminates with a CLRF - """ - response = b'' + """ + response = b'' doubleCLRF = b'\r\n\r\n' while True: bytes = read_bytes(num_bytes_per_read) @@ -348,14 +362,15 @@ def read_handshake(read_bytes, num_bytes_per_read): break return response + def get_payload_length_pair(payload_bytestring): """ A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is larger - than 125 - """ + extended length code to represent the actual length if length code is + larger than 125 + """ actual_length = len(payload_bytestring) - + if actual_length <= 125: length_code = actual_length elif actual_length >= 126 and actual_length <= 65535: @@ -364,6 +379,7 @@ def get_payload_length_pair(payload_bytestring): length_code = 127 return (length_code, actual_length) + def process_handshake_from_client(handshake): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": @@ -371,6 +387,7 @@ def process_handshake_from_client(handshake): key = headers['Sec-WebSocket-Key'] return key + def process_handshake_from_server(handshake, client_nounce): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": @@ -378,12 +395,18 @@ def process_handshake_from_server(handshake, client_nounce): key = headers['Sec-WebSocket-Accept'] return key + def headers_from_http_message(http_message): - return Message(StringIO(http_message.split('\r\n', 1)[1])) + return mimetools.Message( + StringIO.StringIO(http_message.split('\r\n', 1)[1]) + ) + def create_server_nounce(client_nounce): - return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + return base64.b64encode( + hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') + ) + def create_client_nounce(): - return b64encode(os.urandom(16)).decode('utf-8') - + return base64.b64encode(os.urandom(16)).decode('utf-8') From 7defb5be862a4251da9d7c530593f7e9be3e739e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 17 Apr 2015 14:29:20 +1200 Subject: [PATCH 15/15] websockets: more whitespace, WebSocketFrame -> Frame --- netlib/websockets/implementations.py | 12 ++-- netlib/websockets/websockets.py | 100 +++++++++++++-------------- test/test_websockets.py | 45 +++++++----- 3 files changed, 81 insertions(+), 76 deletions(-) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 1ded3b857..337c54964 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -9,7 +9,7 @@ import os # Simple websocket client and servers that are used to exercise the functionality in websockets.py # These are *not* fully RFC6455 compliant -class WebSocketsEchoHandler(tcp.BaseHandler): +class WebSocketsEchoHandler(tcp.BaseHandler): def __init__(self, connection, address, server): super(WebSocketsEchoHandler, self).__init__(connection, address, server) self.handshake_done = False @@ -22,14 +22,14 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload self.on_message(decoded) def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = False) + frame = ws.Frame.default(message, from_client = False) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() - + def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) key = ws.process_handshake_from_client(client_hs) @@ -72,9 +72,9 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + return ws.Frame.from_byte_stream(self.rfile.read).payload def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = True) + frame = ws.Frame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 8782ea496..86d98cafd 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -29,7 +29,7 @@ class WebSocketFrameValidationException(Exception): pass -class WebSocketsFrame(object): +class Frame(object): """ Represents one websockets frame. Constructor takes human readable forms of the frame components @@ -98,29 +98,29 @@ class WebSocketsFrame(object): length_code, actual_length = get_payload_length_pair(message) if from_client: - mask_bit = 1 + mask_bit = 1 masking_key = random_masking_key() - payload = apply_mask(message, masking_key) + payload = apply_mask(message, masking_key) else: - mask_bit = 0 + mask_bit = 0 masking_key = None - payload = message + payload = message return cls( - fin = 1, # final frame - opcode = 1, # text - mask_bit = mask_bit, - payload_length_code = length_code, - payload = payload, - masking_key = masking_key, - decoded_payload = message, + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, actual_payload_length = actual_length ) def is_valid(self): """ - Validate websocket frame invariants, call at anytime to ensure the - WebSocketsFrame has not been corrupted. + Validate websocket frame invariants, call at anytime to ensure the + Frame has not been corrupted. """ try: assert 0 <= self.fin <= 1 @@ -147,17 +147,18 @@ class WebSocketsFrame(object): def human_readable(self): return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length))]) + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)) + ]) def safe_to_bytes(self): if self.is_valid(): @@ -167,11 +168,10 @@ class WebSocketsFrame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring If - you haven't checked is_valid_frame() then there's no guarentees that - the serialized bytes will be correct. see safe_to_bytes() + Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees + that the serialized bytes will be correct. see safe_to_bytes() """ - max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) @@ -199,13 +199,10 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < max_16_bit_int: - # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < max_64_bit_int: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length @@ -215,7 +212,6 @@ class WebSocketsFrame(object): bytes += self.masking_key bytes += self.payload # already will be encoded if neccessary - return bytes @classmethod @@ -264,29 +260,31 @@ class WebSocketsFrame(object): decoded_payload = payload return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, - decoded_payload = decoded_payload, + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, actual_payload_length = actual_payload_length ) def __eq__(self, other): return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.payload_length_code == other.payload_length_code and - self.masking_key == other.masking_key and - self.payload == other.payload and - self.decoded_payload == other.decoded_payload and - self.actual_payload_length == other.actual_payload_length) + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length + ) + def apply_mask(message, masking_key): """ diff --git a/test/test_websockets.py b/test/test_websockets.py index 951aa41ff..d17536383 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -5,6 +5,7 @@ from netlib.websockets import websockets as ws import os from nose.tools import raises + class TestWebSockets(test.ServerTestBase): handler = impl.WebSocketsEchoHandler @@ -22,9 +23,12 @@ class TestWebSockets(test.ServerTestBase): self.echo("hello I'm the client") def test_frame_sizes(self): - small_msg = self.random_bytes(100) # length can fit in the the 7 bit payload length - medium_msg = self.random_bytes(50000) # 50kb, sligthly larger than can fit in a 7 bit int - large_msg = self.random_bytes(150000) # 150kb, slightly larger than can fit in a 16 bit int + # length can fit in the the 7 bit payload length + small_msg = self.random_bytes(100) + # 50kb, sligthly larger than can fit in a 7 bit int + medium_msg = self.random_bytes(50000) + # 150kb, slightly larger than can fit in a 16 bit int + large_msg = self.random_bytes(150000) self.echo(small_msg) self.echo(medium_msg) @@ -33,51 +37,54 @@ class TestWebSockets(test.ServerTestBase): def test_default_builder(self): """ default builder should always generate valid frames - """ + """ msg = self.random_bytes() - client_frame = ws.WebSocketsFrame.default(msg, from_client = True) + client_frame = ws.Frame.default(msg, from_client = True) assert client_frame.is_valid() - server_frame = ws.WebSocketsFrame.default(msg, from_client = False) + server_frame = ws.Frame.default(msg, from_client = False) assert server_frame.is_valid() def test_serialization_bijection(self): """ - Ensure that various frame types can be serialized/deserialized back and forth - between to_bytes() and from_bytes() - """ + Ensure that various frame types can be serialized/deserialized back + and forth between to_bytes() and from_bytes() + """ for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = ws.WebSocketsFrame.default(self.random_bytes(num_bytes), is_client) - assert frame == ws.WebSocketsFrame.from_bytes(frame.to_bytes()) + for num_bytes in [100, 50000, 150000]: + frame = ws.Frame.default( + self.random_bytes(num_bytes), is_client + ) + assert frame == ws.Frame.from_bytes(frame.to_bytes()) bytes = b'\x81\x11cba' - assert ws.WebSocketsFrame.from_bytes(bytes).to_bytes() == bytes + assert ws.Frame.from_bytes(bytes).to_bytes() == bytes @raises(ws.WebSocketFrameValidationException) def test_safe_to_bytes(self): - frame = ws.WebSocketsFrame.default(self.random_bytes(8)) - frame.actual_payload_length = 1 #corrupt the frame + frame = ws.Frame.default(self.random_bytes(8)) + frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() class BadHandshakeHandler(impl.WebSocketsEchoHandler): def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.process_handshake_from_client(client_hs) - response = ws.create_server_handshake("malformed_key") + ws.process_handshake_from_client(client_hs) + response = ws.create_server_handshake("malformed_key") self.wfile.write(response) self.wfile.flush() self.handshake_done = True + class TestBadHandshake(test.ServerTestBase): """ Ensure that the client disconnects if the server handshake is malformed - """ + """ handler = BadHandshakeHandler @raises(tcp.NetLibDisconnect) def test(self): client = impl.WebSocketsClient(("127.0.0.1", self.port)) client.connect() - client.send_message("hello") \ No newline at end of file + client.send_message("hello")