Merge branch 'master' of github.com:mitmproxy/netlib

This commit is contained in:
Maximilian Hils 2015-04-17 16:29:25 +02:00
commit 08ba987a84
13 changed files with 1088 additions and 24 deletions

5
.env Normal file
View File

@ -0,0 +1,5 @@
# Activate the mitmproxy virtualenv that lives next to this checkout, but only
# when VIRTUAL_ENV is unset or empty and the activate script actually exists.
# All expansions are quoted so that paths containing spaces keep working.
DIR=$(dirname "$0")
if [ -z "$VIRTUAL_ENV" ] && [ -f "$DIR/../venv.mitmproxy/bin/activate" ]; then
    echo "Activating mitmproxy virtualenv..."
    source "$DIR/../venv.mitmproxy/bin/activate"
fi

197
netlib/http_cookies.py Normal file
View File

@ -0,0 +1,197 @@
"""
A flexible module for cookie parsing and manipulation.
This module differs from usual standards-compliant cookie modules in a number of
ways. We try to be as permissive as possible, and to retain even mal-formed
information. Duplicate cookies are preserved in parsing, and can be set in
formatting. We do attempt to escape and quote values where needed, but will not
reject data that violate the specs.
Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do
not parse the comma-separated variant of Set-Cookie that allows multiple cookies
to be set in a single header. Technically this should be feasible, but it turns
out that violations of RFC6265 that make the parsing problem indeterminate are
much more common than genuine occurrences of the multi-cookie variants.
Serialization follows RFC6265.
http://tools.ietf.org/html/rfc6265
http://tools.ietf.org/html/rfc2109
http://tools.ietf.org/html/rfc2965
"""
# TODO
# - Disallow LHS-only Cookie values
import re
import odict
def _read_until(s, start, term):
"""
Read until one of the characters in term is reached.
"""
if start == len(s):
return "", start+1
for i in range(start, len(s)):
if s[i] in term:
return s[start:i], i
return s[start:i+1], i+1
def _read_token(s, start):
    """
        Read the name part (the LHS) of a token/value pair in a cookie,
        beginning at offset start.

        Returns the (text, offset) tuple produced by _read_until, scanning
        stops at the first ";" or "=".
    """
    token_delimiters = ";="
    return _read_until(s, start, token_delimiters)
def _read_quoted_string(s, start):
"""
start: offset to the first quote of the string to be read
A sort of loose super-set of the various quoted string specifications.
RFC6265 disallows backslashes or double quotes within quoted strings.
Prior RFCs use backslashes to escape. This leaves us free to apply
backslash escaping by default and be compatible with everything.
"""
escaping = False
ret = []
# Skip the first quote
for i in range(start+1, len(s)):
if escaping:
ret.append(s[i])
escaping = False
elif s[i] == '"':
break
elif s[i] == "\\":
escaping = True
pass
else:
ret.append(s[i])
return "".join(ret), i+1
def _read_value(s, start, delims):
    """
        Reads a value - the RHS of a token/value pair in a cookie.

        s: the string to scan
        start: offset of the first character of the value
        delims: characters that terminate an unquoted value

        A value that begins with a double quote is read as a quoted string;
        anything else is read up to the first delimiter. Returns a
        (text, offset) tuple; an offset at or past the end of s yields
        ("", start).
    """
    if start >= len(s):
        return "", start
    if s[start] == '"':
        return _read_quoted_string(s, start)
    return _read_until(s, start, delims)
def _read_pairs(s, off=0, specials=()):
    """
        Read pairs of lhs=rhs values.

        s: the string to parse
        off: start offset
        specials: a lower-cased list of keys that may contain commas. The
        argument is accepted but not consulted by this function.

        Returns a (pairs, offset) tuple. pairs is a list of [lhs, rhs] lists in
        input order; rhs is None for a name that has no "=" after it.
    """
    pairs = []
    total = len(s)
    while True:
        name, off = _read_token(s, off)
        name = name.lstrip()
        # Empty names (e.g. between two consecutive ";") are dropped.
        if name:
            value = None
            if off < total and s[off] == "=":
                value, off = _read_value(s, off + 1, ";")
            pairs.append([name, value])
        # Step over the delimiter that ended the token or value.
        off += 1
        if off >= total:
            break
    return pairs, off
def _has_special(s):
for i in s:
if i in '",;\\':
return True
o = ord(i)
if o < 0x21 or o > 0x7e:
return True
return False
# Matches the two characters that are backslash-escaped inside quoted values.
ESCAPE = re.compile(r"([\"\\])")


def _format_pairs(lst, specials=(), sep="; "):
    """
        Serialize a list of [key, value] pairs into a single string.

        lst: iterable of (key, value) pairs; a value of None emits the bare key
        specials: A lower-cased list of keys that will not be quoted.
        sep: string placed between the formatted pairs

        Values containing special characters are escaped and wrapped in double
        quotes, unless their key is listed in specials.
    """
    rendered = []
    for key, value in lst:
        if value is None:
            rendered.append(key)
            continue
        needs_quoting = key.lower() not in specials and _has_special(value)
        if needs_quoting:
            value = '"%s"' % ESCAPE.sub(r"\\\1", value)
        rendered.append("%s=%s" % (key, value))
    return sep.join(rendered)
def _format_set_cookie_pairs(lst):
    """
        Format a list of [key, value] pairs as a Set-Cookie header value. The
        "expires" and "path" values are never quoted.
    """
    unquoted_keys = ("expires", "path")
    return _format_pairs(lst, specials=unquoted_keys)
def _parse_set_cookie_pairs(s):
    """
        Parse a Set-Cookie header value into a list of [key, value] lists,
        in the order they appear in s. The final parse offset is discarded.
    """
    comma_keys = ("expires", "path")
    pairs, _ = _read_pairs(s, specials=comma_keys)
    return pairs
def parse_set_cookie_header(str):
    """
        Parse a Set-Cookie header value

        Returns a (name, value, attrs) tuple, or None if no pairs could be
        read. The first pair supplies name and value; the remaining pairs are
        wrapped in an ODictCaseless. No attempt is made to parse attribute
        values - they are treated purely as strings.
    """
    pairs = _parse_set_cookie_pairs(str)
    if not pairs:
        return None
    (name, value), attrs = pairs[0], pairs[1:]
    return name, value, odict.ODictCaseless(attrs)
def format_set_cookie_header(name, value, attrs):
    """
        Formats a Set-Cookie header value.

        name, value: the cookie itself, emitted first
        attrs: an object whose lst attribute holds the [key, value] attribute
        pairs, emitted after the cookie in their stored order
    """
    cookie_and_attrs = [[name, value]] + list(attrs.lst)
    return _format_set_cookie_pairs(cookie_and_attrs)
def parse_cookie_header(str):
    """
        Parse a Cookie header value.

        Returns a (possibly empty) ODict object built from the parsed
        [key, value] pairs.
    """
    parsed = _read_pairs(str)
    return odict.ODict(parsed[0])
def format_cookie_header(od):
    """
        Formats a Cookie header value from the [key, value] pairs stored in
        od.lst.
    """
    cookie_pairs = od.lst
    return _format_pairs(cookie_pairs)

View File

@ -13,7 +13,8 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
class ODict(object): class ODict(object):
""" """
A dictionary-like object for managing ordered (key, value) data. A dictionary-like object for managing ordered (key, value) data. Think
about it as a convenient interface to a list of (key, value) tuples.
""" """
def __init__(self, lst=None): def __init__(self, lst=None):
self.lst = lst or [] self.lst = lst or []
@ -64,11 +65,20 @@ class ODict(object):
key, they are cleared. key, they are cleared.
""" """
if isinstance(valuelist, basestring): if isinstance(valuelist, basestring):
raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") raise ValueError(
"Expected list of values instead of string. "
new = self._filter_lst(k, self.lst) "Example: odict['Host'] = ['www.example.com']"
for i in valuelist: )
new.append([k, i]) kc = self._kconv(k)
new = []
for i in self.lst:
if self._kconv(i[0]) == kc:
if valuelist:
new.append([k, valuelist.pop(0)])
else:
new.append(i)
while valuelist:
new.append([k, valuelist.pop(0)])
self.lst = new self.lst = new
def __delitem__(self, k): def __delitem__(self, k):
@ -84,7 +94,7 @@ class ODict(object):
return False return False
def add(self, key, value): def add(self, key, value):
self.lst.append([key, str(value)]) self.lst.append([key, value])
def get(self, k, d=None): def get(self, k, d=None):
if k in self: if k in self:
@ -108,10 +118,19 @@ class ODict(object):
lst = copy.deepcopy(self.lst) lst = copy.deepcopy(self.lst)
return self.__class__(lst) return self.__class__(lst)
def extend(self, other):
    """
        Append every entry of other.lst to this object's list, in place.
        Entries whose key already exists here are kept as duplicates.
    """
    self.lst += other.lst
def __repr__(self): def __repr__(self):
return repr(self.lst)
def format(self):
elements = [] elements = []
for itm in self.lst: for itm in self.lst:
elements.append(itm[0] + ": " + itm[1]) elements.append(itm[0] + ": " + str(itm[1]))
elements.append("") elements.append("")
return "\r\n".join(elements) return "\r\n".join(elements)

View File

@ -8,6 +8,9 @@ def isascii(s):
return False return False
return True return True
def bytes_to_int(i):
    """
        Interpret the bytestring i as a big-endian unsigned integer.

        Raises ValueError for an empty bytestring.
    """
    # hexlify yields the same hex digits as str.encode('hex') on Python 2, and
    # unlike that codec it also exists on Python 3.
    import binascii
    return int(binascii.hexlify(i), 16)
def cleanBin(s, fixspacing=False): def cleanBin(s, fixspacing=False):
""" """

View File

@ -0,0 +1 @@
from __future__ import (absolute_import, print_function, division)

View File

@ -0,0 +1,80 @@
from netlib import tcp
from base64 import b64encode
from StringIO import StringIO
from . import websockets as ws
import struct
import SocketServer
import os
# Simple websocket client and servers that are used to exercise the functionality in websockets.py
# These are *not* fully RFC6455 compliant
class WebSocketsEchoHandler(tcp.BaseHandler):
    """
        Connection handler that performs the websocket handshake once and then
        sends every message it receives straight back to the peer.
    """

    def __init__(self, connection, address, server):
        super(WebSocketsEchoHandler, self).__init__(connection, address, server)
        self.handshake_done = False

    def handle(self):
        # Loops until one of the calls below raises.
        while True:
            if self.handshake_done:
                self.read_next_message()
            else:
                self.handshake()

    def read_next_message(self):
        """
            Read one frame from the connection and hand its decoded payload to
            on_message.
        """
        frame = ws.Frame.from_byte_stream(self.rfile.read)
        self.on_message(frame.decoded_payload)

    def send_message(self, message):
        """
            Wrap message in a default server-side frame and write it out.
        """
        outgoing = ws.Frame.default(message, from_client = False)
        self.wfile.write(outgoing.safe_to_bytes())
        self.wfile.flush()

    def handshake(self):
        """
            Read the client's handshake, answer it with the server handshake
            and mark the handshake as done.
        """
        request = ws.read_handshake(self.rfile.read, 1)
        client_key = ws.process_handshake_from_client(request)
        reply = ws.create_server_handshake(client_key)
        self.wfile.write(reply)
        self.wfile.flush()
        self.handshake_done = True

    def on_message(self, message):
        # Echo: anything other than None is sent back unchanged.
        if message is None:
            return
        self.send_message(message)
class WebSocketsClient(tcp.TCPClient):
    """
        Minimal websocket client: connects, performs the handshake, then sends
        and receives single frames.
    """

    def __init__(self, address, source_address=None):
        super(WebSocketsClient, self).__init__(address, source_address)
        self.version = "13"
        self.client_nounce = ws.create_client_nounce()
        self.resource = "/"

    def connect(self):
        """
            Open the connection and perform the websocket handshake. The
            connection is closed if the server's accept value does not match
            the one derived from our nounce.
        """
        super(WebSocketsClient, self).connect()

        request = ws.create_client_handshake(
            self.address.host,
            self.address.port,
            self.client_nounce,
            self.version,
            self.resource
        )
        self.wfile.write(request)
        self.wfile.flush()

        reply = ws.read_handshake(self.rfile.read, 1)
        server_nounce = ws.process_handshake_from_server(reply, self.client_nounce)
        expected_nounce = ws.create_server_nounce(self.client_nounce)
        if server_nounce != expected_nounce:
            self.close()

    def read_next_message(self):
        """
            Read one frame from the connection and return its payload.
        """
        frame = ws.Frame.from_byte_stream(self.rfile.read)
        return frame.payload

    def send_message(self, message):
        """
            Wrap message in a default client-side frame and write it out.
        """
        outgoing = ws.Frame.default(message, from_client = True)
        self.wfile.write(outgoing.safe_to_bytes())
        self.wfile.flush()

View File

@ -0,0 +1,410 @@
from __future__ import absolute_import
import base64
import hashlib
import mimetools
import StringIO
import os
import struct
import io
from .. import utils
# Collection of utility functions that implement small portions of the RFC6455
# WebSockets Protocol. Useful for building WebSocket clients and servers.
#
# Emphasis is on readability, simplicity and modularity, not performance or
# completeness.
#
# This is a work in progress and does not yet contain all the utilities needed
# to create fully compliant clients/servers.
#
# Spec: https://tools.ietf.org/html/rfc6455
# The magic GUID that websocket servers must know to prove they understand
# RFC6455
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
class WebSocketFrameValidationException(Exception):
    """
        Raised by Frame.safe_to_bytes() when the frame fails is_valid().
    """
class Frame(object):
    """
        Represents one websockets frame.
        Constructor takes human readable forms of the frame components
        from_bytes() is also available.

        WebSockets Frame as defined in RFC6455

          0                   1                   2                   3
          0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
         +-+-+-+-+-------+-+-------------+-------------------------------+
         |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
         |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
         |N|V|V|V|       |S|             |   (if payload len==126/127)   |
         | |1|2|3|       |K|             |                               |
         +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
         |     Extended payload length continued, if payload len == 127  |
         + - - - - - - - - - - - - - - - +-------------------------------+
         |                               |Masking-key, if MASK set to 1  |
         +-------------------------------+-------------------------------+
         | Masking-key (continued)       |          Payload Data         |
         +-------------------------------- - - - - - - - - - - - - - - - +
         :                     Payload Data continued ...                :
         + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
         |                     Payload Data continued ...                |
         +---------------------------------------------------------------+
    """

    def __init__(
        self,
        fin,                          # decimal integer 1 or 0
        opcode,                       # decimal integer 1 - 4
        mask_bit,                     # decimal integer 1 or 0
        payload_length_code,          # decimal integer 0 - 127
        decoded_payload,              # bytestring
        rsv1 = 0,                     # decimal integer 1 or 0
        rsv2 = 0,                     # decimal integer 1 or 0
        rsv3 = 0,                     # decimal integer 1 or 0
        payload = None,               # bytestring
        masking_key = None,           # 32 bit byte string
        actual_payload_length = None, # any decimal integer
    ):
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask_bit = mask_bit
        self.payload_length_code = payload_length_code
        self.masking_key = masking_key
        self.payload = payload
        self.decoded_payload = decoded_payload
        self.actual_payload_length = actual_payload_length

    @classmethod
    def from_bytes(cls, bytestring):
        """
            Construct a websocket frame from an in-memory bytestring. To
            construct a frame from a stream of bytes, use from_byte_stream()
            directly.
        """
        return cls.from_byte_stream(io.BytesIO(bytestring).read)

    @classmethod
    def default(cls, message, from_client = False):
        """
            Construct a basic websocket frame from some default values.
            Creates a non-fragmented text frame.

            message: the unmasked payload
            from_client: when true the frame is masked with a random key,
            otherwise it is sent unmasked
        """
        length_code, actual_length = get_payload_length_pair(message)

        if from_client:
            mask_bit = 1
            masking_key = random_masking_key()
            payload = apply_mask(message, masking_key)
        else:
            mask_bit = 0
            masking_key = None
            payload = message

        return cls(
            fin = 1, # final frame
            opcode = 1, # text
            mask_bit = mask_bit,
            payload_length_code = length_code,
            payload = payload,
            masking_key = masking_key,
            decoded_payload = message,
            actual_payload_length = actual_length
        )

    def is_valid(self):
        """
            Validate websocket frame invariants, call at anytime to ensure the
            Frame has not been corrupted.

            Returns True or False. Checks are written as plain comparisons
            rather than assert statements so that they still run under
            "python -O".
        """
        single_bits = (self.fin, self.rsv1, self.rsv2, self.rsv3, self.mask_bit)
        if not all(0 <= bit <= 1 for bit in single_bits):
            return False
        # NOTE(review): RFC 6455 also defines opcodes 0x0 and 0x8-0xA, which
        # this range rejects; kept as-is to preserve existing behaviour.
        if not 1 <= self.opcode <= 4:
            return False
        # A length code of 0 is a legitimate empty payload.
        if not 0 <= self.payload_length_code <= 127:
            return False

        if self.mask_bit == 1:
            # apply_mask indexes the key modulo 4, so it must be 4 bytes long.
            if self.masking_key is None or len(self.masking_key) != 4:
                return False
        elif self.masking_key is not None:
            return False

        if self.payload is None:
            return False
        if self.actual_payload_length != len(self.payload):
            return False

        if self.masking_key is not None:
            unmasked = apply_mask(self.payload, self.masking_key)
            if unmasked != self.decoded_payload:
                return False

        return True

    def human_readable(self):
        """
            Return a multi-line string listing every field of the frame.
        """
        return "\n".join([
            ("fin - " + str(self.fin)),
            ("rsv1 - " + str(self.rsv1)),
            ("rsv2 - " + str(self.rsv2)),
            ("rsv3 - " + str(self.rsv3)),
            ("opcode - " + str(self.opcode)),
            ("mask_bit - " + str(self.mask_bit)),
            ("payload_length_code - " + str(self.payload_length_code)),
            ("masking_key - " + str(self.masking_key)),
            ("payload - " + str(self.payload)),
            ("decoded_payload - " + str(self.decoded_payload)),
            ("actual_payload_length - " + str(self.actual_payload_length))
        ])

    def safe_to_bytes(self):
        """
            Serialize the frame like to_bytes(), but raise
            WebSocketFrameValidationException if is_valid() fails.
        """
        if self.is_valid():
            return self.to_bytes()
        else:
            raise WebSocketFrameValidationException()

    def to_bytes(self):
        """
            Serialize the frame back into the wire format, returns a bytestring
            If you haven't checked is_valid() then there's no guarantees
            that the serialized bytes will be correct. see safe_to_bytes()
        """
        max_16_bit_int = (1 << 16)
        max_64_bit_int = (1 << 63)

        # break down of the bit-math used to construct the first byte from the
        # frame's integer values first shift the significant bit into the
        # correct position
        # 00000001 << 7 = 10000000
        # ...
        # then combine:
        #
        # 10000000 fin
        # 01000000 rsv1
        # 00100000 rsv2
        # 00010000 rsv3
        # 00000001 opcode
        # -------- OR
        # 11110001 = first_byte
        first_byte = (
            (self.fin << 7) |
            (self.rsv1 << 6) |
            (self.rsv2 << 5) |
            (self.rsv3 << 4) |
            self.opcode
        )

        second_byte = (self.mask_bit << 7) | self.payload_length_code

        data = chr(first_byte) + chr(second_byte)

        if self.actual_payload_length < 126:
            # the 7 bit length code already carries the full length
            pass
        elif self.actual_payload_length < max_16_bit_int:
            # '!H' pack as 16 bit unsigned short
            # add 2 byte extended payload length
            data += struct.pack('!H', self.actual_payload_length)
        elif self.actual_payload_length < max_64_bit_int:
            # '!Q' = pack as 64 bit unsigned long long
            # add 8 bytes extended payload length
            data += struct.pack('!Q', self.actual_payload_length)

        if self.masking_key is not None:
            data += self.masking_key

        data += self.payload # already will be encoded if necessary

        return data

    @classmethod
    def from_byte_stream(cls, read_bytes):
        """
            read a websockets frame sent by a server or client

            read_bytes is a function that takes a byte count and returns that
            many bytes. It can be backed by sockets or by any byte reader, so
            this function may be used to read frames from disk/wire/memory
        """
        first_byte = utils.bytes_to_int(read_bytes(1))
        second_byte = utils.bytes_to_int(read_bytes(1))

        # grab the left most bit
        fin = first_byte >> 7
        # the three reserved bits sit between fin and the opcode
        rsv1 = (first_byte >> 6) & 1
        rsv2 = (first_byte >> 5) & 1
        rsv3 = (first_byte >> 4) & 1
        # grab right most 4 bits by and-ing with 00001111
        opcode = first_byte & 15
        # grab left most bit
        mask_bit = second_byte >> 7
        # grab the next 7 bits
        payload_length = second_byte & 127

        # payload_length > 125 indicates you need to read more bytes
        # to get the actual payload length
        if payload_length <= 125:
            actual_payload_length = payload_length
        elif payload_length == 126:
            actual_payload_length = utils.bytes_to_int(read_bytes(2))
        else:
            # 127 is the only remaining value of a 7 bit field
            actual_payload_length = utils.bytes_to_int(read_bytes(8))

        # masking key only present if mask bit set
        if mask_bit == 1:
            masking_key = read_bytes(4)
        else:
            masking_key = None

        payload = read_bytes(actual_payload_length)

        if mask_bit == 1:
            decoded_payload = apply_mask(payload, masking_key)
        else:
            decoded_payload = payload

        return cls(
            fin = fin,
            opcode = opcode,
            mask_bit = mask_bit,
            payload_length_code = payload_length,
            payload = payload,
            masking_key = masking_key,
            decoded_payload = decoded_payload,
            actual_payload_length = actual_payload_length,
            rsv1 = rsv1,
            rsv2 = rsv2,
            rsv3 = rsv3
        )

    def __eq__(self, other):
        # Comparing against something that is not a Frame is simply "not
        # equal" rather than an AttributeError.
        if not isinstance(other, Frame):
            return NotImplemented
        return (
            self.fin == other.fin and
            self.rsv1 == other.rsv1 and
            self.rsv2 == other.rsv2 and
            self.rsv3 == other.rsv3 and
            self.opcode == other.opcode and
            self.mask_bit == other.mask_bit and
            self.payload_length_code == other.payload_length_code and
            self.masking_key == other.masking_key and
            self.payload == other.payload and
            self.decoded_payload == other.decoded_payload and
            self.actual_payload_length == other.actual_payload_length
        )

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
def apply_mask(message, masking_key):
    """
        XOR every character of message with the masking key, cycling through
        the key's four bytes.

        Per RFC6455, frames sent from the client to the server are masked;
        frames sent from the server to the client are not.

        Because XOR is its own inverse, this function both encodes and decodes
        strings with the provided mask.

        message: the string to mask or unmask
        masking_key: a 4 character string

        https://tools.ietf.org/html/rfc6455#section-5.3
    """
    masks = [ord(byte) for byte in masking_key]
    # Build the result in one join instead of repeated concatenation.
    return "".join(
        chr(ord(char) ^ masks[index % 4])
        for index, char in enumerate(message)
    )
def random_masking_key():
    """
        Return four bytes from os.urandom, for use as a frame masking key.
    """
    key_length = 4
    return os.urandom(key_length)
def create_client_handshake(host, port, key, version, resource):
    """
        Build the HTTP upgrade request with which a client initiates a
        WebSockets connection.

        host, port: combined into the Host header
        key: value of the Sec-WebSocket-Key header
        version: value of the Sec-WebSocket-Version header
        resource: the path placed on the request line
    """
    request_line = "GET %s HTTP/1.1" % resource
    host_header = '%s:%s' % (host, port)
    return build_handshake(
        [
            ('Host', host_header),
            ('Connection', 'Upgrade'),
            ('Upgrade', 'websocket'),
            ('Sec-WebSocket-Key', key),
            ('Sec-WebSocket-Version', version)
        ],
        request_line
    )
def create_server_handshake(key):
    """
        Build the HTTP 101 response that answers a client handshake.

        key: the client's Sec-WebSocket-Key value, from which the
        Sec-WebSocket-Accept header is derived
    """
    status_line = "HTTP/1.1 101 Switching Protocols"
    accept_value = create_server_nounce(key)
    return build_handshake(
        [
            ('Connection', 'Upgrade'),
            ('Upgrade', 'websocket'),
            ('Sec-WebSocket-Accept', accept_value)
        ],
        status_line
    )
def build_handshake(headers, request):
    """
        Assemble a handshake message.

        headers: iterable of (name, value) pairs, emitted as "name: value"
        request: the first line of the message

        Returns the utf-8 encoded lines joined by CRLF and terminated by an
        empty line.
    """
    lines = [request.encode('utf-8')]
    lines.extend(
        ("%s: %s" % (name, value)).encode('utf-8')
        for name, value in headers
    )
    # Joining a trailing CRLF element produces the closing blank line.
    lines.append(b'\r\n')
    return b'\r\n'.join(lines)
def read_handshake(read_bytes, num_bytes_per_read):
    """
        Read a complete HTTP message head using the provided reader.

        read_bytes: function that takes a byte count and returns bytes
        num_bytes_per_read: the count passed to every read_bytes call

        Reading stops once the accumulated data contains a double CRLF, or
        when read_bytes returns nothing. Everything read so far is returned.
    """
    terminator = b'\r\n\r\n'
    received = b''
    chunk = read_bytes(num_bytes_per_read)
    while chunk:
        received += chunk
        if terminator in received:
            break
        chunk = read_bytes(num_bytes_per_read)
    return received
def get_payload_length_pair(payload_bytestring):
    """
        Work out how the length of a payload is represented in a frame.

        Returns a (length_code, actual_length) tuple. The length code is the
        length itself up to 125, 126 for lengths up to 65535, and 127 for
        anything larger.
    """
    size = len(payload_bytestring)
    if size > 65535:
        code = 127
    elif size > 125:
        code = 126
    else:
        code = size
    return (code, size)
def process_handshake_from_client(handshake):
    """
        Extract the Sec-WebSocket-Key header from a client handshake.

        Returns None when the Upgrade header is not "websocket".
    """
    headers = headers_from_http_message(handshake)
    if headers.get("Upgrade", None) == "websocket":
        return headers['Sec-WebSocket-Key']
    return None
def process_handshake_from_server(handshake, client_nounce):
    """
        Extract the Sec-WebSocket-Accept header from a server handshake.

        Returns None when the Upgrade header is not "websocket". The
        client_nounce argument is accepted but not used here.
    """
    headers = headers_from_http_message(handshake)
    if headers.get("Upgrade", None) == "websocket":
        return headers['Sec-WebSocket-Accept']
    return None
def headers_from_http_message(http_message):
    """
        Parse the headers of an HTTP message into a mimetools.Message.

        Everything after the first CRLF (i.e. after the request or status
        line) is handed to the parser. A message without any CRLF, such as the
        empty string read from a closed connection, yields an empty header set
        instead of raising IndexError.
    """
    # partition() returns "" for the tail when the separator is absent,
    # whereas split(...)[1] would raise.
    header_block = http_message.partition('\r\n')[2]
    return mimetools.Message(StringIO.StringIO(header_block))
def create_server_nounce(client_nounce):
    """
        Derive the server's accept value from the client nounce: the base64
        encoding of the SHA-1 digest of the nounce followed by
        websockets_magic.
    """
    digest = hashlib.sha1(client_nounce + websockets_magic).digest()
    return base64.b64encode(digest)
def create_client_nounce():
    """
        Return 16 bytes from os.urandom, base64 encoded and decoded to text.
    """
    raw = os.urandom(16)
    encoded = base64.b64encode(raw)
    return encoded.decode('utf-8')

View File

@ -1,5 +1,8 @@
from __future__ import (absolute_import, print_function, division) from __future__ import (absolute_import, print_function, division)
import cStringIO, urllib, time, traceback import cStringIO
import urllib
import time
import traceback
from . import odict, tcp from . import odict, tcp
@ -23,15 +26,18 @@ class Request(object):
def date_time_string(): def date_time_string():
"""Return the current date and time formatted for a message header.""" """Return the current date and time formatted for a message header."""
WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
MONTHS = [None, MONTHS = [
None,
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'
]
now = time.time() now = time.time()
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now)
s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
WEEKS[wd], WEEKS[wd],
day, MONTHS[month], year, day, MONTHS[month], year,
hh, mm, ss) hh, mm, ss
)
return s return s
@ -100,6 +106,7 @@ class WSGIAdaptor(object):
status = None, status = None,
headers = None headers = None
) )
def write(data): def write(data):
if not state["headers_sent"]: if not state["headers_sent"]:
soc.write("HTTP/1.1 %s\r\n"%state["status"]) soc.write("HTTP/1.1 %s\r\n"%state["status"])
@ -108,7 +115,7 @@ class WSGIAdaptor(object):
h["Server"] = [self.sversion] h["Server"] = [self.sversion]
if 'date' not in h: if 'date' not in h:
h["Date"] = [date_time_string()] h["Date"] = [date_time_string()]
soc.write(str(h)) soc.write(h.format())
soc.write("\r\n") soc.write("\r\n")
state["headers_sent"] = True state["headers_sent"] = True
if data: if data:
@ -130,7 +137,9 @@ class WSGIAdaptor(object):
errs = cStringIO.StringIO() errs = cStringIO.StringIO()
try: try:
dataiter = self.app(self.make_environ(request, errs, **env), start_response) dataiter = self.app(
self.make_environ(request, errs, **env), start_response
)
for i in dataiter: for i in dataiter:
write(i) write(i)
if not state["headers_sent"]: if not state["headers_sent"]:
@ -143,5 +152,3 @@ class WSGIAdaptor(object):
except Exception: # pragma: no cover except Exception: # pragma: no cover
pass pass
return errs.getvalue() return errs.getvalue()

View File

@ -53,6 +53,7 @@ def test_connection_close():
h["connection"] = ["close"] h["connection"] = ["close"]
assert http.connection_close((1, 1), h) assert http.connection_close((1, 1), h)
def test_get_header_tokens(): def test_get_header_tokens():
h = odict.ODictCaseless() h = odict.ODictCaseless()
assert http.get_header_tokens(h, "foo") == [] assert http.get_header_tokens(h, "foo") == []
@ -69,11 +70,13 @@ def test_read_http_body_request():
r = cStringIO.StringIO("testing") r = cStringIO.StringIO("testing")
assert http.read_http_body(r, h, None, "GET", None, True) == "" assert http.read_http_body(r, h, None, "GET", None, True) == ""
def test_read_http_body_response(): def test_read_http_body_response():
h = odict.ODictCaseless() h = odict.ODictCaseless()
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"
def test_read_http_body(): def test_read_http_body():
# test default case # test default case
h = odict.ODictCaseless() h = odict.ODictCaseless()
@ -115,6 +118,7 @@ def test_read_http_body():
s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"
def test_expected_http_body_size(): def test_expected_http_body_size():
# gibber in the content-length field # gibber in the content-length field
h = odict.ODictCaseless() h = odict.ODictCaseless()
@ -135,6 +139,7 @@ def test_expected_http_body_size():
h = odict.ODictCaseless() h = odict.ODictCaseless()
assert http.expected_http_body_size(h, True, "GET", None) == 0 assert http.expected_http_body_size(h, True, "GET", None) == 0
def test_parse_http_protocol(): def test_parse_http_protocol():
assert http.parse_http_protocol("HTTP/1.1") == (1, 1) assert http.parse_http_protocol("HTTP/1.1") == (1, 1)
assert http.parse_http_protocol("HTTP/0.0") == (0, 0) assert http.parse_http_protocol("HTTP/0.0") == (0, 0)
@ -189,6 +194,7 @@ def test_parse_init_http():
assert not http.parse_init_http("GET /test foo/1.1") assert not http.parse_init_http("GET /test foo/1.1")
assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") assert not http.parse_init_http("GET /test\xc0 HTTP/1.1")
class TestReadHeaders: class TestReadHeaders:
def _read(self, data, verbatim=False): def _read(self, data, verbatim=False):
if not verbatim: if not verbatim:
@ -251,6 +257,7 @@ class TestReadResponseNoContentLength(test.ServerTestBase):
httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None)
assert content == "bar\r\n\r\n" assert content == "bar\r\n\r\n"
def test_read_response(): def test_read_response():
def tst(data, method, limit, include_body=True): def tst(data, method, limit, include_body=True):
data = textwrap.dedent(data) data = textwrap.dedent(data)
@ -351,6 +358,7 @@ def test_parse_url():
# Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt
assert not http.parse_url('http://lo[calhost') assert not http.parse_url('http://lo[calhost')
def test_parse_http_basic_auth(): def test_parse_http_basic_auth():
vals = ("basic", "foo", "bar") vals = ("basic", "foo", "bar")
assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals
@ -358,4 +366,3 @@ def test_parse_http_basic_auth():
assert not http.parse_http_basic_auth("foo bar") assert not http.parse_http_basic_auth("foo bar")
v = "basic " + binascii.b2a_base64("foo") v = "basic " + binascii.b2a_base64("foo")
assert not http.parse_http_basic_auth(v) assert not http.parse_http_basic_auth(v)

220
test/test_http_cookies.py Normal file
View File

@ -0,0 +1,220 @@
import pprint
import nose.tools
from netlib import http_cookies, odict
def test_read_token():
    # Each case is [(string, start offset), (expected token, expected offset)].
    cases = [
        [("foo", 0), ("foo", 3)],
        [("foo", 1), ("oo", 3)],
        [(" foo", 1), ("foo", 4)],
        [(" foo;", 1), ("foo", 4)],
        [(" foo=", 1), ("foo", 4)],
        [(" foo=bar", 1), ("foo", 4)],
    ]
    for args, expected in cases:
        actual = http_cookies._read_token(*args)
        nose.tools.eq_(actual, expected)
def test_read_quoted_string():
    # Each case is [(string, offset of opening quote), (expected text,
    # expected offset after the closing quote)].
    cases = [
        [('"foo" x', 0), ("foo", 5)],
        [('"f\oo" x', 0), ("foo", 6)],
        [(r'"f\\o" x', 0), (r"f\o", 6)],
        [(r'"f\\" x', 0), (r"f" + '\\', 5)],
        [('"fo\\\"" x', 0), ("fo\"", 6)],
    ]
    for args, expected in cases:
        actual = http_cookies._read_quoted_string(*args)
        nose.tools.eq_(actual, expected)
def test_read_pairs():
    # Each case is [input string, expected list of [lhs, rhs] pairs].
    cases = [
        [
            "one",
            [["one", None]]
        ],
        [
            "one=two",
            [["one", "two"]]
        ],
        [
            "one=",
            [["one", ""]]
        ],
        [
            'one="two"',
            [["one", "two"]]
        ],
        [
            'one="two"; three=four',
            [["one", "two"], ["three", "four"]]
        ],
        [
            'one="two"; three=four; five',
            [["one", "two"], ["three", "four"], ["five", None]]
        ],
        [
            'one="\\"two"; three=four',
            [["one", '"two'], ["three", "four"]]
        ],
    ]
    for text, expected in cases:
        parsed, _ = http_cookies._read_pairs(text)
        nose.tools.eq_(parsed, expected)
def test_pairs_roundtrips():
    # Each case is [input string, expected pairs]. The pairs must survive
    # parsing, and also formatting followed by parsing again.
    cases = [
        [
            "",
            []
        ],
        [
            "one=uno",
            [["one", "uno"]]
        ],
        [
            "one",
            [["one", None]]
        ],
        [
            "one=uno; two=due",
            [["one", "uno"], ["two", "due"]]
        ],
        [
            'one="uno"; two="\due"',
            [["one", "uno"], ["two", "due"]]
        ],
        [
            'one="un\\"o"',
            [["one", 'un"o']]
        ],
        [
            'one="uno,due"',
            [["one", 'uno,due']]
        ],
        [
            "one=uno; two; three=tre",
            [["one", "uno"], ["two", None], ["three", "tre"]]
        ],
        [
            "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; "
            "_rcc2=53VdltWl+Ov6ordflA==;",
            [
                ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="],
                ["_rcc2", "53VdltWl+Ov6ordflA=="]
            ]
        ]
    ]
    for text, expected in cases:
        parsed, _ = http_cookies._read_pairs(text)
        nose.tools.eq_(parsed, expected)
        formatted = http_cookies._format_pairs(expected)
        reparsed, _ = http_cookies._read_pairs(formatted)
        nose.tools.eq_(reparsed, expected)
def test_cookie_roundtrips():
    # Each case is [Cookie header value, expected pairs]. The pairs must
    # survive parsing, and also formatting followed by parsing again.
    cases = [
        [
            "one=uno",
            [["one", "uno"]]
        ],
        [
            "one=uno; two=due",
            [["one", "uno"], ["two", "due"]]
        ],
    ]
    for header, expected in cases:
        parsed = http_cookies.parse_cookie_header(header)
        nose.tools.eq_(parsed.lst, expected)
        formatted = http_cookies.format_cookie_header(parsed)
        reparsed = http_cookies.parse_cookie_header(formatted)
        nose.tools.eq_(reparsed.lst, expected)
def test_parse_set_cookie_pairs():
    # Each case is [Set-Cookie header value, expected pairs]. The pairs must
    # survive parsing, and also formatting followed by parsing again.
    cases = [
        [
            "one=uno",
            [
                ["one", "uno"]
            ]
        ],
        [
            "one=un\x20",
            [
                ["one", "un\x20"]
            ]
        ],
        [
            "one=uno; foo",
            [
                ["one", "uno"],
                ["foo", None]
            ]
        ],
        [
            "mun=1.390.f60; "
            "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; "
            "domain=b.aol.com",
            [
                ["mun", "1.390.f60"],
                ["expires", "sun, 11-oct-2015 12:38:31 gmt"],
                ["path", "/"],
                ["domain", "b.aol.com"]
            ]
        ],
        [
            r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; '
            'domain=.rubiconproject.com; '
            'expires=mon, 11-may-2015 21:54:57 gmt; '
            'path=/',
            [
                ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'],
                ['domain', '.rubiconproject.com'],
                ['expires', 'mon, 11-may-2015 21:54:57 gmt'],
                ['path', '/']
            ]
        ],
    ]
    for header, expected in cases:
        parsed = http_cookies._parse_set_cookie_pairs(header)
        nose.tools.eq_(parsed, expected)
        formatted = http_cookies._format_set_cookie_pairs(parsed)
        reparsed = http_cookies._parse_set_cookie_pairs(formatted)
        nose.tools.eq_(reparsed, expected)
def test_parse_set_cookie_header():
    # Each case is [Set-Cookie header value, expected result], where the
    # expected result is None or a (name, value, attribute pairs) tuple.
    cases = [
        [
            "", None
        ],
        [
            ";", None
        ],
        [
            "one=uno",
            ("one", "uno", [])
        ],
        [
            "one=uno; foo=bar",
            ("one", "uno", [["foo", "bar"]])
        ]
    ]
    for header, expected in cases:
        parsed = http_cookies.parse_set_cookie_header(header)
        if not expected:
            assert parsed is None
            continue
        assert parsed[0] == expected[0]
        assert parsed[1] == expected[1]
        nose.tools.eq_(parsed[2].lst, expected[2])
        # Formatting the parsed result and parsing it again must agree too.
        formatted = http_cookies.format_set_cookie_header(*parsed)
        reparsed = http_cookies.parse_set_cookie_header(formatted)
        assert reparsed[0] == expected[0]
        assert reparsed[1] == expected[1]
        nose.tools.eq_(reparsed[2].lst, expected[2])

View File

@ -6,6 +6,11 @@ class TestODict:
def setUp(self): def setUp(self):
self.od = odict.ODict() self.od = odict.ODict()
def test_repr(self):
    # repr() of a populated ODict must be a non-empty string.
    populated = odict.ODict()
    populated["one"] = ["two"]
    assert repr(populated)
def test_str_err(self): def test_str_err(self):
h = odict.ODict() h = odict.ODict()
tutils.raises(ValueError, h.__setitem__, "key", "foo") tutils.raises(ValueError, h.__setitem__, "key", "foo")
@ -20,7 +25,7 @@ class TestODict:
"two: tre\r\n", "two: tre\r\n",
"\r\n" "\r\n"
] ]
out = repr(self.od) out = self.od.format()
for i in expected: for i in expected:
assert out.find(i) >= 0 assert out.find(i) >= 0
@ -39,7 +44,7 @@ class TestODict:
self.od["one"] = ["uno"] self.od["one"] = ["uno"]
expected1 = "one: uno\r\n" expected1 = "one: uno\r\n"
expected2 = "\r\n" expected2 = "\r\n"
out = repr(self.od) out = self.od.format()
assert out.find(expected1) >= 0 assert out.find(expected1) >= 0
assert out.find(expected2) >= 0 assert out.find(expected2) >= 0
@ -109,6 +114,12 @@ class TestODict:
assert self.od.get_first("one") == "two" assert self.od.get_first("one") == "two"
assert self.od.get_first("two") == None assert self.od.get_first("two") == None
def test_extend(self):
    # extend() keeps entries whose key is already present.
    target = odict.ODict([["a", "b"], ["c", "d"]])
    addition = odict.ODict([["a", "b"], ["e", "f"]])
    target.extend(addition)
    assert len(target) == 4
    assert target["a"] == ["b", "b"]
class TestODictCaseless: class TestODictCaseless:
def setUp(self): def setUp(self):
@ -145,3 +156,18 @@ class TestODictCaseless:
self.od.add("bar", 2) self.od.add("bar", 2)
assert len(self.od.keys()) == 2 assert len(self.od.keys()) == 2
def test_add_order(self):
    # Assigning to an existing key replaces its entry in place; surplus
    # values are appended at the end.
    initial = [
        ["one", "uno"],
        ["two", "due"],
        ["three", "tre"],
    ]
    od = odict.ODict(initial)
    od["two"] = ["foo", "bar"]
    expected = [
        ["one", "uno"],
        ["two", "foo"],
        ["three", "tre"],
        ["two", "bar"],
    ]
    assert od.lst == expected

90
test/test_websockets.py Normal file
View File

@ -0,0 +1,90 @@
from netlib import tcp
from netlib import test
from netlib.websockets import implementations as impl
from netlib.websockets import websockets as ws
import os
from nose.tools import raises
class TestWebSockets(test.ServerTestBase):
    handler = impl.WebSocketsEchoHandler

    def random_bytes(self, n = 100):
        return os.urandom(n)

    def echo(self, msg):
        # Send msg to the echo server and expect the very same bytes back.
        client = impl.WebSocketsClient(("127.0.0.1", self.port))
        client.connect()
        client.send_message(msg)
        echoed = client.read_next_message()
        assert echoed == msg

    def test_simple_echo(self):
        self.echo("hello I'm the client")

    def test_frame_sizes(self):
        sizes = (
            100,     # length can fit in the 7 bit payload length
            50000,   # 50kb, larger than can fit in a 7 bit int
            150000,  # 150kb, larger than can fit in a 16 bit int
        )
        for size in sizes:
            self.echo(self.random_bytes(size))

    def test_default_builder(self):
        """
            default builder should always generate valid frames
        """
        msg = self.random_bytes()
        for from_client in (True, False):
            frame = ws.Frame.default(msg, from_client = from_client)
            assert frame.is_valid()

    def test_serialization_bijection(self):
        """
            Ensure that various frame types can be serialized/deserialized back
            and forth between to_bytes() and from_bytes()
        """
        for is_client in [True, False]:
            for num_bytes in [100, 50000, 150000]:
                original = ws.Frame.default(
                    self.random_bytes(num_bytes), is_client
                )
                restored = ws.Frame.from_bytes(original.to_bytes())
                assert original == restored

        raw = b'\x81\x11cba'
        assert ws.Frame.from_bytes(raw).to_bytes() == raw

    @raises(ws.WebSocketFrameValidationException)
    def test_safe_to_bytes(self):
        frame = ws.Frame.default(self.random_bytes(8))
        frame.actual_payload_length = 1 # corrupt the frame
        frame.safe_to_bytes()
class BadHandshakeHandler(impl.WebSocketsEchoHandler):
    # Echo handler whose handshake reply is derived from a bogus key instead
    # of the key the client actually sent.

    def handshake(self):
        request = ws.read_handshake(self.rfile.read, 1)
        ws.process_handshake_from_client(request)
        bogus_reply = ws.create_server_handshake("malformed_key")
        self.wfile.write(bogus_reply)
        self.wfile.flush()
        self.handshake_done = True
class TestBadHandshake(test.ServerTestBase):
    """
        Ensure that the client disconnects if the server handshake is malformed
    """
    handler = BadHandshakeHandler

    @raises(tcp.NetLibDisconnect)
    def test(self):
        server_address = ("127.0.0.1", self.port)
        client = impl.WebSocketsClient(server_address)
        client.connect()
        client.send_message("hello")

View File

@ -100,4 +100,3 @@ class TestWSGI:
start_response(status, response_headers, ei) start_response(status, response_headers, ei)
yield "bbb" yield "bbb"
assert "Internal Server Error" in self._serve(app) assert "Internal Server Error" in self._serve(app)