diff --git a/.env b/.env new file mode 100644 index 000000000..7f847e29f --- /dev/null +++ b/.env @@ -0,0 +1,5 @@ +DIR=`dirname $0` +if [ -z "$VIRTUAL_ENV" ] && [ -f $DIR/../venv.mitmproxy/bin/activate ]; then + echo "Activating mitmproxy virtualenv..." + source $DIR/../venv.mitmproxy/bin/activate +fi diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py new file mode 100644 index 000000000..dab95ed05 --- /dev/null +++ b/netlib/http_cookies.py @@ -0,0 +1,197 @@ +""" +A flexible module for cookie parsing and manipulation. + +This module differs from usual standards-compliant cookie modules in a number of +ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple cookies +to be set in a single header. Technically this should be feasible, but it turns +out that violations of RFC6265 that makes the parsing problem indeterminate are +much more common than genuine occurences of the multi-cookie variants. +Serialization follows RFC6265. + + http://tools.ietf.org/html/rfc6265 + http://tools.ietf.org/html/rfc2109 + http://tools.ietf.org/html/rfc2965 +""" + +# TODO +# - Disallow LHS-only Cookie values + +import re + +import odict + + +def _read_until(s, start, term): + """ + Read until one of the characters in term is reached. + """ + if start == len(s): + return "", start+1 + for i in range(start, len(s)): + if s[i] in term: + return s[start:i], i + return s[start:i+1], i+1 + + +def _read_token(s, start): + """ + Read a token - the LHS of a token/value pair in a cookie. + """ + return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): + """ + start: offset to the first quote of the string to be read + + A sort of loose super-set of the various quoted string specifications. + + RFC6265 disallows backslashes or double quotes within quoted strings. + Prior RFCs use backslashes to escape. This leaves us free to apply + backslash escaping by default and be compatible with everything. + """ + escaping = False + ret = [] + # Skip the first quote + for i in range(start+1, len(s)): + if escaping: + ret.append(s[i]) + escaping = False + elif s[i] == '"': + break + elif s[i] == "\\": + escaping = True + pass + else: + ret.append(s[i]) + return "".join(ret), i+1 + + +def _read_value(s, start, delims): + """ + Reads a value - the RHS of a token/value pair in a cookie. + + special: If the value is special, commas are premitted. Else comma + terminates. This helps us support old and new style values. + """ + if start >= len(s): + return "", start + elif s[start] == '"': + return _read_quoted_string(s, start) + else: + return _read_until(s, start, delims) + + +def _read_pairs(s, off=0, specials=()): + """ + Read pairs of lhs=rhs values. + + off: start offset + specials: a lower-cased list of keys that may contain commas + """ + vals = [] + while 1: + lhs, off = _read_token(s, off) + lhs = lhs.lstrip() + if lhs: + rhs = None + if off < len(s): + if s[off] == "=": + rhs, off = _read_value(s, off+1, ";") + vals.append([lhs, rhs]) + off += 1 + if not off < len(s): + break + return vals, off + + +def _has_special(s): + for i in s: + if i in '",;\\': + return True + o = ord(i) + if o < 0x21 or o > 0x7e: + return True + return False + + +ESCAPE = re.compile(r"([\"\\])") + + +def _format_pairs(lst, specials=(), sep="; "): + """ + specials: A lower-cased list of keys that will not be quoted. + """ + vals = [] + for k, v in lst: + if v is None: + vals.append(k) + else: + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"'%v + vals.append("%s=%s"%(k, v)) + return sep.join(vals) + + +def _format_set_cookie_pairs(lst): + return _format_pairs( + lst, + specials = ("expires", "path") + ) + + +def _parse_set_cookie_pairs(s): + """ + For Set-Cookie, we support multiple cookies as described in RFC2109. + This function therefore returns a list of lists. + """ + pairs, off = _read_pairs( + s, + specials = ("expires", "path") + ) + return pairs + + +def parse_set_cookie_header(str): + """ + Parse a Set-Cookie header value + + Returns a (name, value, attrs) tuple, or None, where attrs is an + ODictCaseless set of attributes. No attempt is made to parse attribute + values - they are treated purely as strings. + """ + pairs = _parse_set_cookie_pairs(str) + if pairs: + return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): + """ + Formats a Set-Cookie header value. + """ + pairs = [[name, value]] + pairs.extend(attrs.lst) + return _format_set_cookie_pairs(pairs) + + +def parse_cookie_header(str): + """ + Parse a Cookie header value. + Returns a (possibly empty) ODict object. + """ + pairs, off = _read_pairs(str) + return odict.ODict(pairs) + + +def format_cookie_header(od): + """ + Formats a Cookie header value. + """ + return _format_pairs(od.lst) diff --git a/netlib/odict.py b/netlib/odict.py index 7a2f611b2..dd738c55d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -13,7 +13,8 @@ def safe_subn(pattern, repl, target, *args, **kwargs): class ODict(object): """ - A dictionary-like object for managing ordered (key, value) data. + A dictionary-like object for managing ordered (key, value) data. Think + about it as a convenient interface to a list of (key, value) tuples. """ def __init__(self, lst=None): self.lst = lst or [] @@ -64,11 +65,20 @@ class ODict(object): key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") - - new = self._filter_lst(k, self.lst) - for i in valuelist: - new.append([k, i]) + raise ValueError( + "Expected list of values instead of string. " + "Example: odict['Host'] = ['www.example.com']" + ) + kc = self._kconv(k) + new = [] + for i in self.lst: + if self._kconv(i[0]) == kc: + if valuelist: + new.append([k, valuelist.pop(0)]) + else: + new.append(i) + while valuelist: + new.append([k, valuelist.pop(0)]) self.lst = new def __delitem__(self, k): @@ -84,7 +94,7 @@ class ODict(object): return False def add(self, key, value): - self.lst.append([key, str(value)]) + self.lst.append([key, value]) def get(self, k, d=None): if k in self: @@ -108,10 +118,19 @@ class ODict(object): lst = copy.deepcopy(self.lst) return self.__class__(lst) + def extend(self, other): + """ + Add the contents of other, preserving any duplicates. + """ + self.lst.extend(other.lst) + def __repr__(self): + return repr(self.lst) + + def format(self): elements = [] for itm in self.lst: - elements.append(itm[0] + ": " + itm[1]) + elements.append(itm[0] + ": " + str(itm[1])) elements.append("") return "\r\n".join(elements) diff --git a/netlib/utils.py b/netlib/utils.py index 79077ac60..03a70977c 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -8,6 +8,9 @@ def isascii(s): return False return True +# best way to do it in python 2.x +def bytes_to_int(i): + return int(i.encode('hex'), 16) def cleanBin(s, fixspacing=False): """ diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 000000000..9b4faa337 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py new file mode 100644 index 000000000..337c54964 --- /dev/null +++ b/netlib/websockets/implementations.py @@ -0,0 +1,80 @@ +from netlib import tcp +from base64 import b64encode +from StringIO import StringIO +from . import websockets as ws +import struct +import SocketServer +import os + +# Simple websocket client and servers that are used to exercise the functionality in websockets.py +# These are *not* fully RFC6455 compliant + +class WebSocketsEchoHandler(tcp.BaseHandler): + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__(connection, address, server) + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload + self.on_message(decoded) + + def send_message(self, message): + frame = ws.Frame.default(message, from_client = False) + self.wfile.write(frame.safe_to_bytes()) + self.wfile.flush() + + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + key = ws.process_handshake_from_client(client_hs) + response = ws.create_server_handshake(key) + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.version = "13" + self.client_nounce = ws.create_client_nounce() + self.resource = "/" + + def connect(self): + super(WebSocketsClient, self).connect() + + handshake = ws.create_client_handshake( + self.address.host, + self.address.port, + self.client_nounce, + self.version, + self.resource + ) + + self.wfile.write(handshake) + self.wfile.flush() + + server_handshake = ws.read_handshake(self.rfile.read, 1) + + server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) + + if not server_nounce == ws.create_server_nounce(self.client_nounce): + self.close() + + def read_next_message(self): + return ws.Frame.from_byte_stream(self.rfile.read).payload + + def send_message(self, message): + frame = ws.Frame.default(message, from_client = True) + self.wfile.write(frame.safe_to_bytes()) + self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py new file mode 100644 index 000000000..86d98cafd --- /dev/null +++ b/netlib/websockets/websockets.py @@ -0,0 +1,410 @@ +from __future__ import absolute_import + +import base64 +import hashlib +import mimetools +import StringIO +import os +import struct +import io + +from .. import utils + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness +# +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand +# RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + + +class WebSocketFrameValidationException(Exception): + pass + + +class Frame(object): + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + def __init__( + self, + fin, # decmial integer 1 or 0 + opcode, # decmial integer 1 - 4 + mask_bit, # decimal integer 1 or 0 + payload_length_code, # decimal integer 1 - 127 + decoded_payload, # bytestring + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string + actual_payload_length = None, # any decimal integer + ): + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload + self.actual_payload_length = actual_payload_length + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring to construct + a frame from a stream of bytes, use from_byte_stream() directly + """ + return cls.from_byte_stream(io.BytesIO(bytestring).read) + + @classmethod + def default(cls, message, from_client = False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + length_code, actual_length = get_payload_length_pair(message) + + if from_client: + mask_bit = 1 + masking_key = random_masking_key() + payload = apply_mask(message, masking_key) + else: + mask_bit = 0 + masking_key = None + payload = message + + return cls( + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, + actual_payload_length = actual_length + ) + + def is_valid(self): + """ + Validate websocket frame invariants, call at anytime to ensure the + Frame has not been corrupted. + """ + try: + assert 0 <= self.fin <= 1 + assert 0 <= self.rsv1 <= 1 + assert 0 <= self.rsv2 <= 1 + assert 0 <= self.rsv3 <= 1 + assert 1 <= self.opcode <= 4 + assert 0 <= self.mask_bit <= 1 + assert 1 <= self.payload_length_code <= 127 + + if self.mask_bit == 1: + assert 1 <= len(self.masking_key) <= 4 + else: + assert self.masking_key is None + + assert self.actual_payload_length == len(self.payload) + + if self.payload is not None and self.masking_key is not None: + assert apply_mask(self.payload, self.masking_key) == self.decoded_payload + + return True + except AssertionError: + return False + + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)) + ]) + + def safe_to_bytes(self): + if self.is_valid(): + return self.to_bytes() + else: + raise WebSocketFrameValidationException() + + def to_bytes(self): + """ + Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees + that the serialized bytes will be correct. see safe_to_bytes() + """ + max_16_bit_int = (1 << 16) + max_64_bit_int = (1 << 63) + + # break down of the bit-math used to construct the first byte from the + # frame's integer values first shift the significant bit into the + # correct position + # 00000001 << 7 = 10000000 + # ... + # then combine: + # + # 10000000 fin + # 01000000 res1 + # 00100000 res2 + # 00010000 res3 + # 00000001 opcode + # -------- OR + # 11110001 = first_byte + + first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ + (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + + second_byte = (self.mask_bit << 7) | self.payload_length_code + + bytes = chr(first_byte) + chr(second_byte) + + if self.actual_payload_length < 126: + pass + elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + bytes += struct.pack('!H', self.actual_payload_length) + elif self.actual_payload_length < max_64_bit_int: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + bytes += struct.pack('!Q', self.actual_payload_length) + + if self.masking_key is not None: + bytes += self.masking_key + + bytes += self.payload # already will be encoded if neccessary + return bytes + + @classmethod + def from_byte_stream(cls, read_bytes): + """ + read a websockets frame sent by a server or client + + read_bytes is a function that can be backed + by sockets or by any byte reader. So this + function may be used to read frames from disk/wire/memory + """ + first_byte = utils.bytes_to_int(read_bytes(1)) + second_byte = utils.bytes_to_int(read_bytes(1)) + + # grab the left most bit + fin = first_byte >> 7 + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + payload_length = second_byte & 127 + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if payload_length <= 125: + actual_payload_length = payload_length + + elif payload_length == 126: + actual_payload_length = utils.bytes_to_int(read_bytes(2)) + + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = read_bytes(4) + else: + masking_key = None + + payload = read_bytes(actual_payload_length) + + if mask_bit == 1: + decoded_payload = apply_mask(payload, masking_key) + else: + decoded_payload = payload + + return cls( + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, + actual_payload_length = actual_payload_length + ) + + def __eq__(self, other): + return ( + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length + ) + + +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + + +def random_masking_key(): + return os.urandom(4) + + +def create_client_handshake(host, port, key, version, resource): + """ + WebSockets connections are intiated by the client with a valid HTTP + upgrade request + """ + headers = [ + ('Host', '%s:%s' % (host, port)), + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ] + request = "GET %s HTTP/1.1" % resource + return build_handshake(headers, request) + + +def create_server_handshake(key): + """ + The server response is a valid HTTP 101 response. + """ + headers = [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', create_server_nounce(key)) + ] + request = "HTTP/1.1 101 Switching Protocols" + return build_handshake(headers, request) + + +def build_handshake(headers, request): + handshake = [request.encode('utf-8')] + for header, value in headers: + handshake.append(("%s: %s" % (header, value)).encode('utf-8')) + handshake.append(b'\r\n') + return b'\r\n'.join(handshake) + + +def read_handshake(read_bytes, num_bytes_per_read): + """ + From provided function that reads bytes, read in a + complete HTTP request, which terminates with a CLRF + """ + response = b'' + doubleCLRF = b'\r\n\r\n' + while True: + bytes = read_bytes(num_bytes_per_read) + if not bytes: + break + response += bytes + if doubleCLRF in response: + break + return response + + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + + +def process_handshake_from_client(handshake): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Key'] + return key + + +def process_handshake_from_server(handshake, client_nounce): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Accept'] + return key + + +def headers_from_http_message(http_message): + return mimetools.Message( + StringIO.StringIO(http_message.split('\r\n', 1)[1]) + ) + + +def create_server_nounce(client_nounce): + return base64.b64encode( + hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') + ) + + +def create_client_nounce(): + return base64.b64encode(os.urandom(16)).decode('utf-8') diff --git a/netlib/wsgi.py b/netlib/wsgi.py index bac27d5af..1b9796081 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,5 +1,8 @@ from __future__ import (absolute_import, print_function, division) -import cStringIO, urllib, time, traceback +import cStringIO +import urllib +import time +import traceback from . import odict, tcp @@ -23,15 +26,18 @@ class Request(object): def date_time_string(): """Return the current date and time formatted for a message header.""" WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - MONTHS = [None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + MONTHS = [ + None, + 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', + 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' + ] now = time.time() year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - WEEKS[wd], - day, MONTHS[month], year, - hh, mm, ss) + WEEKS[wd], + day, MONTHS[month], year, + hh, mm, ss + ) return s @@ -100,6 +106,7 @@ class WSGIAdaptor(object): status = None, headers = None ) + def write(data): if not state["headers_sent"]: soc.write("HTTP/1.1 %s\r\n"%state["status"]) @@ -108,7 +115,7 @@ class WSGIAdaptor(object): h["Server"] = [self.sversion] if 'date' not in h: h["Date"] = [date_time_string()] - soc.write(str(h)) + soc.write(h.format()) soc.write("\r\n") state["headers_sent"] = True if data: @@ -130,7 +137,9 @@ class WSGIAdaptor(object): errs = cStringIO.StringIO() try: - dataiter = self.app(self.make_environ(request, errs, **env), start_response) + dataiter = self.app( + self.make_environ(request, errs, **env), start_response + ) for i in dataiter: write(i) if not state["headers_sent"]: @@ -143,5 +152,3 @@ class WSGIAdaptor(object): except Exception: # pragma: no cover pass return errs.getvalue() - - diff --git a/test/test_http.py b/test/test_http.py index fed609464..b1c62458b 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -53,6 +53,7 @@ def test_connection_close(): h["connection"] = ["close"] assert http.connection_close((1, 1), h) + def test_get_header_tokens(): h = odict.ODictCaseless() assert http.get_header_tokens(h, "foo") == [] @@ -69,11 +70,13 @@ def test_read_http_body_request(): r = cStringIO.StringIO("testing") assert http.read_http_body(r, h, None, "GET", None, True) == "" + def test_read_http_body_response(): h = odict.ODictCaseless() s = cStringIO.StringIO("testing") assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" + def test_read_http_body(): # test default case h = odict.ODictCaseless() @@ -115,6 +118,7 @@ def test_read_http_body(): s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + def test_expected_http_body_size(): # gibber in the content-length field h = odict.ODictCaseless() @@ -135,6 +139,7 @@ def test_expected_http_body_size(): h = odict.ODictCaseless() assert http.expected_http_body_size(h, True, "GET", None) == 0 + def test_parse_http_protocol(): assert http.parse_http_protocol("HTTP/1.1") == (1, 1) assert http.parse_http_protocol("HTTP/0.0") == (0, 0) @@ -189,6 +194,7 @@ def test_parse_init_http(): assert not http.parse_init_http("GET /test foo/1.1") assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") + class TestReadHeaders: def _read(self, data, verbatim=False): if not verbatim: @@ -251,11 +257,12 @@ class TestReadResponseNoContentLength(test.ServerTestBase): httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) assert content == "bar\r\n\r\n" + def test_read_response(): def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) r = cStringIO.StringIO(data) - return http.read_response(r, method, limit, include_body=include_body) + return http.read_response(r, method, limit, include_body = include_body) tutils.raises("server disconnect", tst, "", "GET", None) tutils.raises("invalid server response", tst, "foo", "GET", None) @@ -351,6 +358,7 @@ def test_parse_url(): # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt assert not http.parse_url('http://lo[calhost') + def test_parse_http_basic_auth(): vals = ("basic", "foo", "bar") assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals @@ -358,4 +366,3 @@ def test_parse_http_basic_auth(): assert not http.parse_http_basic_auth("foo bar") v = "basic " + binascii.b2a_base64("foo") assert not http.parse_http_basic_auth(v) - diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py new file mode 100644 index 000000000..7438af7ca --- /dev/null +++ b/test/test_http_cookies.py @@ -0,0 +1,220 @@ +import pprint +import nose.tools + +from netlib import http_cookies, odict + + +def test_read_token(): + tokens = [ + [("foo", 0), ("foo", 3)], + [("foo", 1), ("oo", 3)], + [(" foo", 1), ("foo", 4)], + [(" foo;", 1), ("foo", 4)], + [(" foo=", 1), ("foo", 4)], + [(" foo=bar", 1), ("foo", 4)], + ] + for q, a in tokens: + nose.tools.eq_(http_cookies._read_token(*q), a) + + +def test_read_quoted_string(): + tokens = [ + [('"foo" x', 0), ("foo", 5)], + [('"f\oo" x', 0), ("foo", 6)], + [(r'"f\\o" x', 0), (r"f\o", 6)], + [(r'"f\\" x', 0), (r"f" + '\\', 5)], + [('"fo\\\"" x', 0), ("fo\"", 6)], + ] + for q, a in tokens: + nose.tools.eq_(http_cookies._read_quoted_string(*q), a) + + +def test_read_pairs(): + vals = [ + [ + "one", + [["one", None]] + ], + [ + "one=two", + [["one", "two"]] + ], + [ + "one=", + [["one", ""]] + ], + [ + 'one="two"', + [["one", "two"]] + ], + [ + 'one="two"; three=four', + [["one", "two"], ["three", "four"]] + ], + [ + 'one="two"; three=four; five', + [["one", "two"], ["three", "four"], ["five", None]] + ], + [ + 'one="\\"two"; three=four', + [["one", '"two'], ["three", "four"]] + ], + ] + for s, lst in vals: + ret, off = http_cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + + +def test_pairs_roundtrips(): + pairs = [ + [ + "", + [] + ], + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one", + [["one", None]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="uno"; two="\due"', + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="un\\"o"', + [["one", 'un"o']] + ], + [ + 'one="uno,due"', + [["one", 'uno,due']] + ], + [ + "one=uno; two; three=tre", + [["one", "uno"], ["two", None], ["three", "tre"]] + ], + [ + "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " + "_rcc2=53VdltWl+Ov6ordflA==;", + [ + ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], + ["_rcc2", "53VdltWl+Ov6ordflA=="] + ] + ] + ] + for s, lst in pairs: + ret, off = http_cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + s2 = http_cookies._format_pairs(lst) + ret, off = http_cookies._read_pairs(s2) + nose.tools.eq_(ret, lst) + + +def test_cookie_roundtrips(): + pairs = [ + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + ] + for s, lst in pairs: + ret = http_cookies.parse_cookie_header(s) + nose.tools.eq_(ret.lst, lst) + s2 = http_cookies.format_cookie_header(ret) + ret = http_cookies.parse_cookie_header(s2) + nose.tools.eq_(ret.lst, lst) + + +def test_parse_set_cookie_pairs(): + pairs = [ + [ + "one=uno", + [ + ["one", "uno"] + ] + ], + [ + "one=un\x20", + [ + ["one", "un\x20"] + ] + ], + [ + "one=uno; foo", + [ + ["one", "uno"], + ["foo", None] + ] + ], + [ + "mun=1.390.f60; " + "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " + "domain=b.aol.com", + [ + ["mun", "1.390.f60"], + ["expires", "sun, 11-oct-2015 12:38:31 gmt"], + ["path", "/"], + ["domain", "b.aol.com"] + ] + ], + [ + r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' + 'domain=.rubiconproject.com; ' + 'expires=mon, 11-may-2015 21:54:57 gmt; ' + 'path=/', + [ + ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], + ['domain', '.rubiconproject.com'], + ['expires', 'mon, 11-may-2015 21:54:57 gmt'], + ['path', '/'] + ] + ], + ] + for s, lst in pairs: + ret = http_cookies._parse_set_cookie_pairs(s) + nose.tools.eq_(ret, lst) + s2 = http_cookies._format_set_cookie_pairs(ret) + ret2 = http_cookies._parse_set_cookie_pairs(s2) + nose.tools.eq_(ret2, lst) + + +def test_parse_set_cookie_header(): + vals = [ + [ + "", None + ], + [ + ";", None + ], + [ + "one=uno", + ("one", "uno", []) + ], + [ + "one=uno; foo=bar", + ("one", "uno", [["foo", "bar"]]) + ] + ] + for s, expected in vals: + ret = http_cookies.parse_set_cookie_header(s) + if expected: + assert ret[0] == expected[0] + assert ret[1] == expected[1] + nose.tools.eq_(ret[2].lst, expected[2]) + s2 = http_cookies.format_set_cookie_header(*ret) + ret2 = http_cookies.parse_set_cookie_header(s2) + assert ret2[0] == expected[0] + assert ret2[1] == expected[1] + nose.tools.eq_(ret2[2].lst, expected[2]) + else: + assert ret is None diff --git a/test/test_odict.py b/test/test_odict.py index d90bc6e56..c01c4dbe4 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -6,6 +6,11 @@ class TestODict: def setUp(self): self.od = odict.ODict() + def test_repr(self): + h = odict.ODict() + h["one"] = ["two"] + assert repr(h) + def test_str_err(self): h = odict.ODict() tutils.raises(ValueError, h.__setitem__, "key", "foo") @@ -20,7 +25,7 @@ class TestODict: "two: tre\r\n", "\r\n" ] - out = repr(self.od) + out = self.od.format() for i in expected: assert out.find(i) >= 0 @@ -39,7 +44,7 @@ class TestODict: self.od["one"] = ["uno"] expected1 = "one: uno\r\n" expected2 = "\r\n" - out = repr(self.od) + out = self.od.format() assert out.find(expected1) >= 0 assert out.find(expected2) >= 0 @@ -109,6 +114,12 @@ class TestODict: assert self.od.get_first("one") == "two" assert self.od.get_first("two") == None + def test_extend(self): + a = odict.ODict([["a", "b"], ["c", "d"]]) + b = odict.ODict([["a", "b"], ["e", "f"]]) + a.extend(b) + assert len(a) == 4 + assert a["a"] == ["b", "b"] class TestODictCaseless: def setUp(self): @@ -145,3 +156,18 @@ class TestODictCaseless: self.od.add("bar", 2) assert len(self.od.keys()) == 2 + def test_add_order(self): + od = odict.ODict( + [ + ["one", "uno"], + ["two", "due"], + ["three", "tre"], + ] + ) + od["two"] = ["foo", "bar"] + assert od.lst == [ + ["one", "uno"], + ["two", "foo"], + ["three", "tre"], + ["two", "bar"], + ] diff --git a/test/test_websockets.py b/test/test_websockets.py new file mode 100644 index 000000000..d17536383 --- /dev/null +++ b/test/test_websockets.py @@ -0,0 +1,90 @@ +from netlib import tcp +from netlib import test +from netlib.websockets import implementations as impl +from netlib.websockets import websockets as ws +import os +from nose.tools import raises + + +class TestWebSockets(test.ServerTestBase): + handler = impl.WebSocketsEchoHandler + + def random_bytes(self, n = 100): + return os.urandom(n) + + def echo(self, msg): + client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message(msg) + response = client.read_next_message() + assert response == msg + + def test_simple_echo(self): + self.echo("hello I'm the client") + + def test_frame_sizes(self): + # length can fit in the the 7 bit payload length + small_msg = self.random_bytes(100) + # 50kb, sligthly larger than can fit in a 7 bit int + medium_msg = self.random_bytes(50000) + # 150kb, slightly larger than can fit in a 16 bit int + large_msg = self.random_bytes(150000) + + self.echo(small_msg) + self.echo(medium_msg) + self.echo(large_msg) + + def test_default_builder(self): + """ + default builder should always generate valid frames + """ + msg = self.random_bytes() + client_frame = ws.Frame.default(msg, from_client = True) + assert client_frame.is_valid() + + server_frame = ws.Frame.default(msg, from_client = False) + assert server_frame.is_valid() + + def test_serialization_bijection(self): + """ + Ensure that various frame types can be serialized/deserialized back + and forth between to_bytes() and from_bytes() + """ + for is_client in [True, False]: + for num_bytes in [100, 50000, 150000]: + frame = ws.Frame.default( + self.random_bytes(num_bytes), is_client + ) + assert frame == ws.Frame.from_bytes(frame.to_bytes()) + + bytes = b'\x81\x11cba' + assert ws.Frame.from_bytes(bytes).to_bytes() == bytes + + @raises(ws.WebSocketFrameValidationException) + def test_safe_to_bytes(self): + frame = ws.Frame.default(self.random_bytes(8)) + frame.actual_payload_length = 1 # corrupt the frame + frame.safe_to_bytes() + + +class BadHandshakeHandler(impl.WebSocketsEchoHandler): + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + ws.process_handshake_from_client(client_hs) + response = ws.create_server_handshake("malformed_key") + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + + +class TestBadHandshake(test.ServerTestBase): + """ + Ensure that the client disconnects if the server handshake is malformed + """ + handler = BadHandshakeHandler + + @raises(tcp.NetLibDisconnect) + def test(self): + client = impl.WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message("hello") diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 6e1fb146a..1c8c52635 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -100,4 +100,3 @@ class TestWSGI: start_response(status, response_headers, ei) yield "bbb" assert "Internal Server Error" in self._serve(app) -