Merge pull request #85 from Kriechi/http2-wip

add move tests and code from mitmproxy
This commit is contained in:
Maximilian Hils 2015-08-11 10:57:32 +02:00
commit f3a6113391
20 changed files with 898 additions and 425 deletions

View File

@ -2,7 +2,6 @@ from __future__ import (absolute_import, print_function, division)
from argparse import Action, ArgumentTypeError
import binascii
from .. import http
def parse_http_basic_auth(s):
words = s.split()
@ -37,7 +36,6 @@ class NullProxyAuth(object):
"""
Clean up authentication headers, so they're not passed upstream.
"""
pass
def authenticate(self, headers_):
"""

View File

@ -1,6 +1,8 @@
from netlib import odict
class HttpError(Exception):
def __init__(self, code, message):
super(HttpError, self).__init__(message)
self.code = code
@ -11,6 +13,7 @@ class HttpErrorConnClosed(HttpError):
class HttpAuthenticationError(Exception):
def __init__(self, auth_headers=None):
super(HttpAuthenticationError, self).__init__(
"Proxy Authentication Required"

View File

@ -1,28 +1,31 @@
from __future__ import (absolute_import, print_function, division)
import binascii
import collections
import string
import sys
import urlparse
import time
from netlib import odict, utils, tcp, http
from netlib.http import semantics
from .. import status_codes
from ..exceptions import *
class TCPHandler(object):
def __init__(self, rfile, wfile=None):
self.rfile = rfile
self.wfile = wfile
class HTTP1Protocol(semantics.ProtocolMixin):
def __init__(self, tcp_handler=None, rfile=None, wfile=None):
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
def read_request(
self,
include_body=True,
body_size_limit=None,
allow_empty=False,
):
"""
Parse an HTTP request from a file stream
@ -129,8 +132,12 @@ class HTTP1Protocol(semantics.ProtocolMixin):
timestamp_end,
)
def read_response(self, request_method, body_size_limit, include_body=True):
def read_response(
self,
request_method,
body_size_limit,
include_body=True,
):
"""
Returns an http.Response
@ -175,7 +182,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
# read separately
body = None
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
@ -195,7 +201,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
timestamp_end=timestamp_end,
)
def assemble_request(self, request):
assert isinstance(request, semantics.Request)
@ -208,7 +213,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
headers = self._assemble_request_headers(request)
return "%s\r\n%s\r\n%s" % (first_line, headers, request.body)
def assemble_response(self, response):
assert isinstance(response, semantics.Response)
@ -221,7 +225,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
headers = self._assemble_response_headers(response)
return "%s\r\n%s\r\n%s" % (first_line, headers, response.body)
def read_headers(self):
"""
Read a set of headers.
@ -321,9 +324,14 @@ class HTTP1Protocol(semantics.ProtocolMixin):
"HTTP Body too large. Limit is %s," % limit
)
@classmethod
def expected_http_body_size(self, headers, is_request, request_method, response_code):
def expected_http_body_size(
self,
headers,
is_request,
request_method,
response_code,
):
"""
Returns the expected body length:
- a positive integer, if the size is known in advance
@ -359,20 +367,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
return -1
@classmethod
def request_preamble(self, method, resource, http_major="1", http_minor="1"):
return '%s %s HTTP/%s.%s' % (
method, resource, http_major, http_minor
)
@classmethod
def response_preamble(self, code, message=None, http_major="1", http_minor="1"):
if message is None:
message = status_codes.RESPONSES.get(code)
return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message)
@classmethod
def has_chunked_encoding(self, headers):
return "chunked" in [
@ -390,7 +384,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
line = self.tcp_handler.rfile.readline()
return line
def _read_chunked(self, limit, is_request):
"""
Read a chunked HTTP body.
@ -427,7 +420,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
if length == 0:
return
@classmethod
def _parse_http_protocol(self, line):
"""
@ -447,7 +439,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
return None
return major, minor
@classmethod
def _parse_init(self, line):
try:
@ -461,7 +452,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
return None
return method, url, httpversion
@classmethod
def _parse_init_connect(self, line):
"""
@ -489,7 +479,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
return None
return host, port, httpversion
@classmethod
def _parse_init_proxy(self, line):
v = self._parse_init(line)
@ -503,7 +492,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
scheme, host, port, path = parts
return method, scheme, host, port, path, httpversion
@classmethod
def _parse_init_http(self, line):
"""
@ -519,7 +507,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
return None
return method, url, httpversion
@classmethod
def connection_close(self, httpversion, headers):
"""
@ -539,7 +526,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
# be persistent
return httpversion != (1, 1)
@classmethod
def parse_response_line(self, line):
parts = line.strip().split(" ", 2)
@ -554,7 +540,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
return None
return (proto, code, msg)
@classmethod
def _assemble_request_first_line(self, request):
return request.legacy_first_line()
@ -575,7 +560,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
return headers.format()
def _assemble_response_first_line(self, response):
return 'HTTP/%s.%s %s %s' % (
response.httpversion[0],
@ -584,7 +568,11 @@ class HTTP1Protocol(semantics.ProtocolMixin):
response.msg,
)
def _assemble_response_headers(self, response, preserve_transfer_encoding=False):
def _assemble_response_headers(
self,
response,
preserve_transfer_encoding=False,
):
headers = response.headers.copy()
for k in response._headers_to_strip_off:
del headers[k]

View File

@ -9,6 +9,7 @@ from . import frame
class TCPHandler(object):
def __init__(self, rfile, wfile=None):
self.rfile = rfile
self.wfile = wfile
@ -39,7 +40,6 @@ class HTTP2Protocol(semantics.ProtocolMixin):
ALPN_PROTO_H2 = 'h2'
def __init__(
self,
tcp_handler=None,
@ -60,7 +60,12 @@ class HTTP2Protocol(semantics.ProtocolMixin):
self.current_stream_id = None
self.connection_preface_performed = False
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
def read_request(
self,
include_body=True,
body_size_limit=None,
allow_empty=False,
):
self.perform_connection_preface()
timestamp_start = time.time()
@ -92,7 +97,12 @@ class HTTP2Protocol(semantics.ProtocolMixin):
return request
def read_response(self, request_method='', body_size_limit=None, include_body=True):
def read_response(
self,
request_method='',
body_size_limit=None,
include_body=True,
):
self.perform_connection_preface()
timestamp_start = time.time()
@ -123,7 +133,6 @@ class HTTP2Protocol(semantics.ProtocolMixin):
return response
def assemble_request(self, request):
assert isinstance(request, semantics.Request)
@ -133,13 +142,13 @@ class HTTP2Protocol(semantics.ProtocolMixin):
headers = request.headers.copy()
if not ':authority' in headers.keys():
if ':authority' not in headers.keys():
headers.add(':authority', bytes(authority), prepend=True)
if not ':scheme' in headers.keys():
if ':scheme' not in headers.keys():
headers.add(':scheme', bytes(request.scheme), prepend=True)
if not ':path' in headers.keys():
if ':path' not in headers.keys():
headers.add(':path', bytes(request.path), prepend=True)
if not ':method' in headers.keys():
if ':method' not in headers.keys():
headers.add(':method', bytes(request.method), prepend=True)
headers = headers.items()
@ -158,7 +167,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
headers = response.headers.copy()
if not ':status' in headers.keys():
if ':status' not in headers.keys():
headers.add(':status', bytes(str(response.status_code)), prepend=True)
headers = headers.items()

View File

@ -1,13 +1,9 @@
from __future__ import (absolute_import, print_function, division)
import binascii
import collections
import string
import sys
import urllib
import urlparse
from .. import utils, odict
from . import cookies
from . import cookies, exceptions
from netlib import utils, encoding
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
@ -18,11 +14,11 @@ CONTENT_MISSING = 0
class ProtocolMixin(object):
def read_request(self):
raise NotImplemented
def read_request(self, *args, **kwargs): # pragma: no cover
raise NotImplementedError
def read_response(self):
raise NotImplemented
def read_response(self, *args, **kwargs): # pragma: no cover
raise NotImplementedError
def assemble(self, message):
if isinstance(message, Request):
@ -32,14 +28,23 @@ class ProtocolMixin(object):
else:
raise ValueError("HTTP message not supported.")
def assemble_request(self, request):
raise NotImplemented
def assemble_request(self, *args, **kwargs): # pragma: no cover
raise NotImplementedError
def assemble_response(self, response):
raise NotImplemented
def assemble_response(self, *args, **kwargs): # pragma: no cover
raise NotImplementedError
class Request(object):
# This list is adopted legacy code.
# We probably don't need to strip off keep-alive.
_headers_to_strip_off = [
'Proxy-Connection',
'Keep-Alive',
'Connection',
'Transfer-Encoding',
'Upgrade',
]
def __init__(
self,
@ -71,7 +76,6 @@ class Request(object):
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
def __eq__(self, other):
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
@ -114,7 +118,7 @@ class Request(object):
self.httpversion[1],
)
else:
raise http.HttpError(400, "Invalid request form")
raise exceptions.HttpError(400, "Invalid request form")
def anticache(self):
"""
@ -143,7 +147,7 @@ class Request(object):
if self.headers["accept-encoding"]:
self.headers["accept-encoding"] = [
', '.join(
e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])]
e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))]
def update_host_header(self):
"""
@ -317,17 +321,18 @@ class Request(object):
self.scheme, self.host, self.port, self.path = parts
@property
def content(self):
def content(self): # pragma: no cover
# TODO: remove deprecated getter
return self.body
@content.setter
def content(self, content):
def content(self, content): # pragma: no cover
# TODO: remove deprecated setter
self.body = content
class EmptyRequest(Request):
def __init__(self):
super(EmptyRequest, self).__init__(
form_in="",
@ -343,6 +348,11 @@ class EmptyRequest(Request):
class Response(object):
_headers_to_strip_off = [
'Proxy-Connection',
'Alternate-Protocol',
'Alt-Svc',
]
def __init__(
self,
@ -368,7 +378,6 @@ class Response(object):
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
def __eq__(self, other):
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
@ -388,11 +397,9 @@ class Response(object):
status_code=self.status_code,
msg=self.msg,
contenttype=self.headers.get_first(
"content-type", "unknown content type"
),
size=size
)
"content-type",
"unknown content type"),
size=size)
def get_cookies(self):
"""
@ -430,21 +437,21 @@ class Response(object):
self.headers["Set-Cookie"] = values
@property
def content(self):
def content(self): # pragma: no cover
# TODO: remove deprecated getter
return self.body
@content.setter
def content(self, content):
def content(self, content): # pragma: no cover
# TODO: remove deprecated setter
self.body = content
@property
def code(self):
def code(self): # pragma: no cover
# TODO: remove deprecated getter
return self.status_code
@code.setter
def code(self, code):
def code(self, code): # pragma: no cover
# TODO: remove deprecated setter
self.status_code = code

View File

@ -91,8 +91,9 @@ class ODict(object):
self.lst = self._filter_lst(k, self.lst)
def __contains__(self, k):
k = self._kconv(k)
for i in self.lst:
if self._kconv(i[0]) == self._kconv(k):
if self._kconv(i[0]) == k:
return True
return False

View File

@ -69,8 +69,6 @@ def raises(exc, obj, *args, **kwargs):
test_data = utils.Data(__name__)
def treq(content="content", scheme="http", host="address", port=22):
"""
@return: libmproxy.protocol.http.HTTPRequest
@ -119,7 +117,7 @@ def tresp(content="message"):
"OK",
headers,
content,
time.time(),
time.time(),
timestamp_start=time.time(),
timestamp_end=time.time(),
)
return resp

View File

@ -4,6 +4,7 @@ import cgi
import urllib
import urlparse
import string
import re
def isascii(s):
@ -118,6 +119,7 @@ def pretty_size(size):
class Data(object):
def __init__(self, name):
m = __import__(name)
dirname, _ = os.path.split(m.__file__)
@ -136,8 +138,6 @@ class Data(object):
return fullpath
def is_valid_port(port):
if not 0 <= port <= 65535:
return False
@ -220,6 +220,7 @@ def hostport(scheme, host, port):
else:
return "%s:%s" % (host, port)
def unparse_url(scheme, host, port, path=""):
"""
Returns a URL string, constructed from the specified compnents.
@ -234,8 +235,64 @@ def urlencode(s):
s = [tuple(i) for i in s]
return urllib.urlencode(s, False)
def urldecode(s):
"""
Takes a urlencoded string and returns a list of (key, value) tuples.
"""
return cgi.parse_qsl(s, keep_blank_values=True)
def parse_content_type(c):
"""
A simple parser for content-type values. Returns a (type, subtype,
parameters) tuple, where type and subtype are strings, and parameters
is a dict. If the string could not be parsed, return None.
E.g. the following string:
text/html; charset=UTF-8
Returns:
("text", "html", {"charset": "UTF-8"})
"""
parts = c.split(";", 1)
ts = parts[0].split("/", 1)
if len(ts) != 2:
return None
d = {}
if len(parts) == 2:
for i in parts[1].split(";"):
clause = i.split("=", 1)
if len(clause) == 2:
d[clause[0].strip()] = clause[1].strip()
return ts[0].lower(), ts[1].lower(), d
def multipartdecode(hdrs, content):
"""
Takes a multipart boundary encoded string and returns list of (key, value) tuples.
"""
v = hdrs.get_first("content-type")
if v:
v = parse_content_type(v)
if not v:
return []
boundary = v[2].get("boundary")
if not boundary:
return []
rx = re.compile(r'\bname="([^"]+)"')
r = []
for i in content.split("--" + boundary):
parts = i.splitlines()
if len(parts) > 1 and parts[0][0:2] != "--":
match = rx.search(parts[1])
if match:
key = match.group(1)
value = "".join(parts[3 + parts[2:].index(""):])
r.append((key, value))
return r
return []

View File

@ -1,12 +1,11 @@
from __future__ import absolute_import
import base64
import hashlib
import os
import struct
import io
from .protocol import Masker
from netlib import utils, odict, tcp
from netlib import tcp
from netlib import utils
DEFAULT = object()
@ -22,6 +21,7 @@ OPCODE = utils.BiDi(
PONG=0x0a
)
class FrameHeader(object):
def __init__(

View File

@ -2,10 +2,9 @@ from __future__ import absolute_import
import base64
import hashlib
import os
import struct
import io
from netlib import utils, odict, tcp
from netlib import odict
from netlib import utils
# Colleciton of utility functions that implement small portions of the RFC6455
# WebSockets Protocol Useful for building WebSocket clients and servers.
@ -26,6 +25,7 @@ HEADER_WEBSOCKET_KEY = 'sec-websocket-key'
HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept'
HEADER_WEBSOCKET_VERSION = 'sec-websocket-version'
class Masker(object):
"""
@ -53,6 +53,7 @@ class Masker(object):
self.offset += len(ret)
return ret
class WebsocketsProtocol(object):
def __init__(self):

View File

@ -1,18 +1,41 @@
import cStringIO
import textwrap
import binascii
from netlib import http, odict, tcp, tutils
from netlib.http import semantics
from netlib.http.http1 import HTTP1Protocol
from ... import tservers
def mock_protocol(data='', chunked=False):
class NoContentLengthHTTPHandler(tcp.BaseHandler):
def handle(self):
self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n")
self.wfile.flush()
def mock_protocol(data=''):
rfile = cStringIO.StringIO(data)
wfile = cStringIO.StringIO()
return HTTP1Protocol(rfile=rfile, wfile=wfile)
def match_http_string(data):
return textwrap.dedent(data).strip().replace('\n', '\r\n')
def test_stripped_chunked_encoding_no_content():
"""
https://github.com/mitmproxy/mitmproxy/issues/186
"""
r = tutils.treq(content="")
r.headers["Transfer-Encoding"] = ["chunked"]
assert "Content-Length" in mock_protocol()._assemble_request_headers(r)
r = tutils.tresp(content="")
r.headers["Transfer-Encoding"] = ["chunked"]
assert "Content-Length" in mock_protocol()._assemble_response_headers(r)
def test_has_chunked_encoding():
h = odict.ODictCaseless()
@ -75,7 +98,6 @@ def test_connection_close():
assert HTTP1Protocol.connection_close((1, 1), h)
def test_read_http_body_request():
h = odict.ODictCaseless()
data = "testing"
@ -85,7 +107,7 @@ def test_read_http_body_request():
def test_read_http_body_response():
h = odict.ODictCaseless()
data = "testing"
assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing"
assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing"
def test_read_http_body():
@ -129,13 +151,13 @@ def test_read_http_body():
# test no content length: limit > actual content
h = odict.ODictCaseless()
data = "testing"
assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7
assert len(mock_protocol(data).read_http_body(h, 100, "GET", 200, False)) == 7
# test no content length: limit < actual content
data = "testing"
tutils.raises(
http.HttpError,
mock_protocol(data, chunked=True).read_http_body,
mock_protocol(data).read_http_body,
h, 4, "GET", 200, False
)
@ -143,7 +165,7 @@ def test_read_http_body():
h = odict.ODictCaseless()
h["transfer-encoding"] = ["chunked"]
data = "5\r\naaaaa\r\n0\r\n\r\n"
assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa"
assert mock_protocol(data).read_http_body(h, 100, "GET", 200, False) == "aaaaa"
def test_expected_http_body_size():
@ -167,6 +189,13 @@ def test_expected_http_body_size():
assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0
def test_get_request_line():
data = "\nfoo"
p = mock_protocol(data)
assert p._get_request_line() == "foo"
assert not p._get_request_line()
def test_parse_http_protocol():
assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1)
assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0)
@ -269,96 +298,7 @@ class TestReadHeaders:
assert self._read(data) is None
class NoContentLengthHTTPHandler(tcp.BaseHandler):
def handle(self):
self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n")
self.wfile.flush()
class TestReadResponseNoContentLength(tservers.ServerTestBase):
handler = NoContentLengthHTTPHandler
def test_no_content_length(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
resp = HTTP1Protocol(c).read_response("GET", None)
assert resp.body == "bar\r\n\r\n"
def test_read_response():
def tst(data, method, body_size_limit, include_body=True):
data = textwrap.dedent(data)
return mock_protocol(data).read_response(
method, body_size_limit, include_body=include_body
)
tutils.raises("server disconnect", tst, "", "GET", None)
tutils.raises("invalid server response", tst, "foo", "GET", None)
data = """
HTTP/1.1 200 OK
"""
assert tst(data, "GET", None) == http.Response(
(1, 1), 200, 'OK', odict.ODictCaseless(), ''
)
data = """
HTTP/1.1 200
"""
assert tst(data, "GET", None) == http.Response(
(1, 1), 200, '', odict.ODictCaseless(), ''
)
data = """
HTTP/x 200 OK
"""
tutils.raises("invalid http version", tst, data, "GET", None)
data = """
HTTP/1.1 xx OK
"""
tutils.raises("invalid server response", tst, data, "GET", None)
data = """
HTTP/1.1 100 CONTINUE
HTTP/1.1 200 OK
"""
assert tst(data, "GET", None) == http.Response(
(1, 1), 100, 'CONTINUE', odict.ODictCaseless(), ''
)
data = """
HTTP/1.1 200 OK
Content-Length: 3
foo
"""
assert tst(data, "GET", None).body == 'foo'
assert tst(data, "HEAD", None).body == ''
data = """
HTTP/1.1 200 OK
\tContent-Length: 3
foo
"""
tutils.raises("invalid headers", tst, data, "GET", None)
data = """
HTTP/1.1 200 OK
Content-Length: 3
foo
"""
assert tst(data, "GET", None, include_body=False).body is None
def test_get_request_line():
data = "\nfoo"
p = mock_protocol(data)
assert p._get_request_line() == "foo"
assert not p._get_request_line()
class TestReadRequest():
class TestReadRequest(object):
def tst(self, data, **kwargs):
return mock_protocol(data).read_request(**kwargs)
@ -385,6 +325,10 @@ class TestReadRequest():
"\r\n"
)
def test_empty(self):
v = self.tst("", allow_empty=True)
assert isinstance(v, semantics.EmptyRequest)
def test_asterisk_form_in(self):
v = self.tst("OPTIONS * HTTP/1.1")
assert v.form_in == "relative"
@ -427,3 +371,131 @@ class TestReadRequest():
assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n"
assert v.body == "foo"
assert p.tcp_handler.rfile.read(3) == "bar"
class TestReadResponse(object):
def tst(self, data, method, body_size_limit, include_body=True):
data = textwrap.dedent(data)
return mock_protocol(data).read_response(
method, body_size_limit, include_body=include_body
)
def test_errors(self):
tutils.raises("server disconnect", self.tst, "", "GET", None)
tutils.raises("invalid server response", self.tst, "foo", "GET", None)
def test_simple(self):
data = """
HTTP/1.1 200
"""
assert self.tst(data, "GET", None) == http.Response(
(1, 1), 200, '', odict.ODictCaseless(), ''
)
def test_simple_message(self):
data = """
HTTP/1.1 200 OK
"""
assert self.tst(data, "GET", None) == http.Response(
(1, 1), 200, 'OK', odict.ODictCaseless(), ''
)
def test_invalid_http_version(self):
data = """
HTTP/x 200 OK
"""
tutils.raises("invalid http version", self.tst, data, "GET", None)
def test_invalid_status_code(self):
data = """
HTTP/1.1 xx OK
"""
tutils.raises("invalid server response", self.tst, data, "GET", None)
def test_valid_with_continue(self):
data = """
HTTP/1.1 100 CONTINUE
HTTP/1.1 200 OK
"""
assert self.tst(data, "GET", None) == http.Response(
(1, 1), 100, 'CONTINUE', odict.ODictCaseless(), ''
)
def test_simple_body(self):
data = """
HTTP/1.1 200 OK
Content-Length: 3
foo
"""
assert self.tst(data, "GET", None).body == 'foo'
assert self.tst(data, "HEAD", None).body == ''
def test_invalid_headers(self):
data = """
HTTP/1.1 200 OK
\tContent-Length: 3
foo
"""
tutils.raises("invalid headers", self.tst, data, "GET", None)
def test_without_body(self):
data = """
HTTP/1.1 200 OK
Content-Length: 3
foo
"""
assert self.tst(data, "GET", None, include_body=False).body is None
class TestReadResponseNoContentLength(tservers.ServerTestBase):
handler = NoContentLengthHTTPHandler
def test_no_content_length(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
resp = HTTP1Protocol(c).read_response("GET", None)
assert resp.body == "bar\r\n\r\n"
class TestAssembleRequest(object):
def test_simple(self):
req = tutils.treq()
b = HTTP1Protocol().assemble_request(req)
assert b == match_http_string("""
GET /path HTTP/1.1
header: qvalue
Host: address:22
Content-Length: 7
content""")
def test_body_missing(self):
req = tutils.treq(content=semantics.CONTENT_MISSING)
tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req)
def test_not_a_request(self):
tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo')
class TestAssembleResponse(object):
def test_simple(self):
resp = tutils.tresp()
b = HTTP1Protocol().assemble_response(resp)
print(b)
assert b == match_http_string("""
HTTP/1.1 200 OK
header_response: svalue
Content-Length: 7
message""")
def test_body_missing(self):
resp = tutils.tresp(content=semantics.CONTENT_MISSING)
tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp)
def test_not_a_request(self):
tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo')

View File

@ -1,10 +1,25 @@
import OpenSSL
import mock
from netlib import tcp, odict, http, tutils
from netlib.http import http2
from netlib.http.http2 import HTTP2Protocol
from netlib.http.http2.frame import *
from ... import tservers
class TestTCPHandlerWrapper:
def test_wrapped(self):
h = http2.TCPHandler(rfile='foo', wfile='bar')
p = HTTP2Protocol(h)
assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar'
def test_direct(self):
p = HTTP2Protocol(rfile='foo', wfile='bar')
assert isinstance(p.tcp_handler, http2.TCPHandler)
assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar'
class EchoHandler(tcp.BaseHandler):
sni = None
@ -16,10 +31,40 @@ class EchoHandler(tcp.BaseHandler):
self.wfile.flush()
class TestProtocol:
@mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
@mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
def test_perform_connection_preface(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=False)
protocol.connection_preface_performed = True
protocol.perform_connection_preface()
assert not mock_client_method.called
assert not mock_server_method.called
protocol.perform_connection_preface(force=True)
assert mock_client_method.called
assert not mock_server_method.called
@mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
@mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
def test_perform_connection_preface_server(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=True)
protocol.connection_preface_performed = True
protocol.perform_connection_preface()
assert not mock_client_method.called
assert not mock_server_method.called
protocol.perform_connection_preface(force=True)
assert not mock_client_method.called
assert mock_server_method.called
class TestCheckALPNMatch(tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2,
alpn_select=HTTP2Protocol.ALPN_PROTO_H2,
)
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
@ -27,8 +72,8 @@ class TestCheckALPNMatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2])
protocol = http2.HTTP2Protocol(c)
c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2])
protocol = HTTP2Protocol(c)
assert protocol.check_alpn()
@ -43,8 +88,8 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2])
protocol = http2.HTTP2Protocol(c)
c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2])
protocol = HTTP2Protocol(c)
tutils.raises(NotImplementedError, protocol.check_alpn)
@ -76,8 +121,13 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
def test_perform_server_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
protocol = http2.HTTP2Protocol(c)
protocol = HTTP2Protocol(c)
assert not protocol.connection_preface_performed
protocol.perform_server_connection_preface()
assert protocol.connection_preface_performed
tutils.raises(tcp.NetLibIncomplete, protocol.perform_server_connection_preface, force=True)
class TestPerformClientConnectionPreface(tservers.ServerTestBase):
@ -107,13 +157,16 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
def test_perform_client_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
protocol = http2.HTTP2Protocol(c)
protocol = HTTP2Protocol(c)
assert not protocol.connection_preface_performed
protocol.perform_client_connection_preface()
assert protocol.connection_preface_performed
class TestClientStreamIds():
c = tcp.TCPClient(("127.0.0.1", 0))
protocol = http2.HTTP2Protocol(c)
protocol = HTTP2Protocol(c)
def test_client_stream_ids(self):
assert self.protocol.current_stream_id is None
@ -127,7 +180,7 @@ class TestClientStreamIds():
class TestServerStreamIds():
c = tcp.TCPClient(("127.0.0.1", 0))
protocol = http2.HTTP2Protocol(c, is_server=True)
protocol = HTTP2Protocol(c, is_server=True)
def test_server_stream_ids(self):
assert self.protocol.current_stream_id is None
@ -154,7 +207,7 @@ class TestApplySettings(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c)
protocol = HTTP2Protocol(c)
protocol._apply_settings({
SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo',
@ -182,13 +235,13 @@ class TestCreateHeaders():
(b':scheme', b'https'),
(b'foo', b'bar')]
bytes = http2.HTTP2Protocol(self.c)._create_headers(
bytes = HTTP2Protocol(self.c)._create_headers(
headers, 1, end_stream=True)
assert b''.join(bytes) ==\
'000014010500000001824488355217caf3a69a3f87408294e7838c767f'\
.decode('hex')
bytes = http2.HTTP2Protocol(self.c)._create_headers(
bytes = HTTP2Protocol(self.c)._create_headers(
headers, 1, end_stream=False)
assert b''.join(bytes) ==\
'000014010400000001824488355217caf3a69a3f87408294e7838c767f'\
@ -199,7 +252,7 @@ class TestCreateHeaders():
class TestCreateBody():
c = tcp.TCPClient(("127.0.0.1", 0))
protocol = http2.HTTP2Protocol(c)
protocol = HTTP2Protocol(c)
def test_create_body_empty(self):
bytes = self.protocol._create_body(b'', 1)
@ -215,98 +268,6 @@ class TestCreateBody():
# TODO: add test for too large frames
class TestAssembleRequest():
c = tcp.TCPClient(("127.0.0.1", 0))
def test_assemble_request_simple(self):
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
'',
'GET',
'https',
'',
'',
'/',
(2, 0),
None,
None,
))
assert len(bytes) == 1
assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex')
def test_assemble_request_with_body(self):
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
'',
'GET',
'https',
'',
'',
'/',
(2, 0),
odict.ODictCaseless([('foo', 'bar')]),
'foobar',
))
assert len(bytes) == 2
assert bytes[0] ==\
'0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex')
assert bytes[1] ==\
'000006000100000001666f6f626172'.decode('hex')
class TestReadResponse(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
b'00000801040000000188628594e78c767f'.decode('hex'))
self.wfile.write(
b'000006000100000001666f6f626172'.decode('hex'))
self.wfile.flush()
ssl = True
def test_read_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response()
assert resp.httpversion == (2, 0)
assert resp.status_code == 200
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b'foobar'
class TestReadEmptyResponse(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
b'00000801050000000188628594e78c767f'.decode('hex'))
self.wfile.flush()
ssl = True
def test_read_empty_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response()
assert resp.stream_id
assert resp.httpversion == (2, 0)
assert resp.status_code == 200
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b''
class TestReadRequest(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
@ -323,7 +284,7 @@ class TestReadRequest(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = http2.HTTP2Protocol(c, is_server=True)
protocol = HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True
resp = protocol.read_request()
@ -333,11 +294,138 @@ class TestReadRequest(tservers.ServerTestBase):
assert resp.body == b'foobar'
class TestCreateResponse():
class TestReadResponse(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
b'00000801040000000188628594e78c767f'.decode('hex'))
self.wfile.write(
b'000006000100000001666f6f626172'.decode('hex'))
self.wfile.flush()
ssl = True
def test_read_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = HTTP2Protocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response()
assert resp.httpversion == (2, 0)
assert resp.status_code == 200
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b'foobar'
assert resp.timestamp_end
def test_read_response_no_body(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = HTTP2Protocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response(include_body=False)
assert resp.httpversion == (2, 0)
assert resp.status_code == 200
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b'foobar' # TODO: this should be true: assert resp.body == http.CONTENT_MISSING
assert not resp.timestamp_end
class TestReadEmptyResponse(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
b'00000801050000000188628594e78c767f'.decode('hex'))
self.wfile.flush()
ssl = True
def test_read_empty_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
protocol = HTTP2Protocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response()
assert resp.stream_id
assert resp.httpversion == (2, 0)
assert resp.status_code == 200
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b''
class TestAssembleRequest(object):
c = tcp.TCPClient(("127.0.0.1", 0))
def test_create_response_simple(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
def test_request_simple(self):
bytes = HTTP2Protocol(self.c).assemble_request(http.Request(
'',
'GET',
'https',
'',
'',
'/',
(2, 0),
None,
None,
))
assert len(bytes) == 1
assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex')
def test_request_with_stream_id(self):
req = http.Request(
'',
'GET',
'https',
'',
'',
'/',
(2, 0),
None,
None,
)
req.stream_id = 0x42
bytes = HTTP2Protocol(self.c).assemble_request(req)
assert len(bytes) == 1
print(bytes[0].encode('hex'))
assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex')
def test_request_with_body(self):
bytes = HTTP2Protocol(self.c).assemble_request(http.Request(
'',
'GET',
'https',
'',
'',
'/',
(2, 0),
odict.ODictCaseless([('foo', 'bar')]),
'foobar',
))
assert len(bytes) == 2
assert bytes[0] ==\
'0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex')
assert bytes[1] ==\
'000006000100000001666f6f626172'.decode('hex')
class TestAssembleResponse(object):
c = tcp.TCPClient(("127.0.0.1", 0))
def test_simple(self):
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
(2, 0),
200,
))
@ -345,8 +433,19 @@ class TestCreateResponse():
assert bytes[0] ==\
'00000101050000000288'.decode('hex')
def test_create_response_with_body(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
def test_with_stream_id(self):
resp = http.Response(
(2, 0),
200,
)
resp.stream_id = 0x42
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp)
assert len(bytes) == 1
assert bytes[0] ==\
'00000101050000004288'.decode('hex')
def test_with_body(self):
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
(2, 0),
200,
'',

View File

@ -1,6 +1,27 @@
from netlib.http.exceptions import *
from netlib import odict
def test_HttpAuthenticationError():
x = HttpAuthenticationError({"foo": "bar"})
class TestHttpError:
def test_simple(self):
e = HttpError(404, "Not found")
assert str(e)
class TestHttpAuthenticationError:
def test_init(self):
headers = odict.ODictCaseless([("foo", "bar")])
x = HttpAuthenticationError(headers)
assert str(x)
assert "foo" in x.headers
assert isinstance(x.headers, odict.ODictCaseless)
assert x.code == 407
assert x.headers == headers
print(x.headers.keys())
assert "foo" in x.headers.keys()
def test_header_conversion(self):
headers = {"foo": "bar"}
x = HttpAuthenticationError(headers)
assert isinstance(x.headers, odict.ODictCaseless)
assert x.headers.lst == headers.items()
def test_repr(self):
assert repr(HttpAuthenticationError()) == "Proxy Authentication Required"

View File

@ -1,18 +1,275 @@
import cStringIO
import textwrap
import binascii
from mock import MagicMock
import mock
from netlib import http, odict, tcp, tutils
from netlib.http import http1
from netlib import http
from netlib import odict
from netlib import tutils
from netlib import utils
from netlib.http import semantics
from netlib.http.semantics import CONTENT_MISSING
from .. import tservers
def test_httperror():
e = http.exceptions.HttpError(404, "Not found")
assert str(e)
class TestProtocolMixin(object):
@mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
@mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
def test_assemble_request(self, mock_request_method, mock_response_method):
p = semantics.ProtocolMixin()
p.assemble(tutils.treq())
assert mock_request_method.called
assert not mock_response_method.called
@mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
@mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
def test_assemble_response(self, mock_request_method, mock_response_method):
p = semantics.ProtocolMixin()
p.assemble(tutils.tresp())
assert not mock_request_method.called
assert mock_response_method.called
def test_assemble_foo(self):
p = semantics.ProtocolMixin()
tutils.raises(ValueError, p.assemble, 'foo')
class TestRequest(object):
def test_repr(self):
r = tutils.treq()
assert repr(r)
def test_headers_odict(self):
tutils.raises(AssertionError, semantics.Request,
'form_in',
'method',
'scheme',
'host',
'port',
'path',
(1, 1),
'foobar',
)
req = semantics.Request(
'form_in',
'method',
'scheme',
'host',
'port',
'path',
(1, 1),
)
assert isinstance(req.headers, odict.ODictCaseless)
def test_equal(self):
a = tutils.treq()
b = tutils.treq()
assert a == b
assert not a == 'foo'
assert not b == 'foo'
assert not 'foo' == a
assert not 'foo' == b
def test_legacy_first_line(self):
req = tutils.treq()
req.form_in = 'relative'
assert req.legacy_first_line() == "GET /path HTTP/1.1"
req.form_in = 'authority'
assert req.legacy_first_line() == "GET address:22 HTTP/1.1"
req.form_in = 'absolute'
assert req.legacy_first_line() == "GET http://address:22/path HTTP/1.1"
req.form_in = 'foobar'
tutils.raises(http.HttpError, req.legacy_first_line)
def test_anticache(self):
req = tutils.treq()
req.headers.add("If-Modified-Since", "foo")
req.headers.add("If-None-Match", "bar")
req.anticache()
assert "If-Modified-Since" not in req.headers
assert "If-None-Match" not in req.headers
def test_anticomp(self):
req = tutils.treq()
req.headers.add("Accept-Encoding", "foobar")
req.anticomp()
assert req.headers["Accept-Encoding"] == ["identity"]
def test_constrain_encoding(self):
req = tutils.treq()
req.headers.add("Accept-Encoding", "identity, gzip, foo")
req.constrain_encoding()
assert "foo" not in req.headers.get_first("Accept-Encoding")
def test_update_host(self):
req = tutils.treq()
req.headers.add("Host", "")
req.host = "foobar"
req.update_host_header()
assert req.headers.get_first("Host") == "foobar"
def test_get_form(self):
req = tutils.treq()
assert req.get_form() == odict.ODict()
@mock.patch("netlib.http.semantics.Request.get_form_multipart")
@mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart):
req = tutils.treq()
assert req.get_form() == odict.ODict()
req = tutils.treq()
req.body = "foobar"
req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED]
req.get_form()
assert req.get_form_urlencoded.called
assert not req.get_form_multipart.called
@mock.patch("netlib.http.semantics.Request.get_form_multipart")
@mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart):
req = tutils.treq()
req.body = "foobar"
req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART]
req.get_form()
assert not req.get_form_urlencoded.called
assert req.get_form_multipart.called
def test_get_form_urlencoded(self):
req = tutils.treq("foobar")
assert req.get_form_urlencoded() == odict.ODict()
req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED]
assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body))
def test_get_form_multipart(self):
req = tutils.treq("foobar")
assert req.get_form_multipart() == odict.ODict()
req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART]
assert req.get_form_multipart() == odict.ODict(
utils.multipartdecode(
req.headers,
req.body))
def test_set_form_urlencoded(self):
req = tutils.treq()
req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')]))
assert req.headers.get_first("Content-Type") == semantics.HDR_FORM_URLENCODED
assert req.body
def test_get_path_components(self):
req = tutils.treq()
assert req.get_path_components()
# TODO: add meaningful assertions
def test_set_path_components(self):
req = tutils.treq()
req.set_path_components(["foo", "bar"])
# TODO: add meaningful assertions
def test_get_query(self):
req = tutils.treq()
assert req.get_query().lst == []
req.url = "http://localhost:80/foo?bar=42"
assert req.get_query().lst == [("bar", "42")]
def test_set_query(self):
req = tutils.treq()
req.set_query(odict.ODict([]))
def test_pretty_host(self):
r = tutils.treq()
assert r.pretty_host(True) == "address"
assert r.pretty_host(False) == "address"
r.headers["host"] = ["other"]
assert r.pretty_host(True) == "other"
assert r.pretty_host(False) == "address"
r.host = None
assert r.pretty_host(True) == "other"
assert r.pretty_host(False) is None
del r.headers["host"]
assert r.pretty_host(True) is None
assert r.pretty_host(False) is None
# Invalid IDNA
r.headers["host"] = [".disqus.com"]
assert r.pretty_host(True) == ".disqus.com"
def test_pretty_url(self):
req = tutils.treq()
req.form_out = "authority"
assert req.pretty_url(True) == "address:22"
assert req.pretty_url(False) == "address:22"
req.form_out = "relative"
assert req.pretty_url(True) == "http://address:22/path"
assert req.pretty_url(False) == "http://address:22/path"
def test_get_cookies_none(self):
h = odict.ODictCaseless()
r = tutils.treq()
r.headers = h
assert len(r.get_cookies()) == 0
def test_get_cookies_single(self):
h = odict.ODictCaseless()
h["Cookie"] = ["cookiename=cookievalue"]
r = tutils.treq()
r.headers = h
result = r.get_cookies()
assert len(result) == 1
assert result['cookiename'] == ['cookievalue']
def test_get_cookies_double(self):
h = odict.ODictCaseless()
h["Cookie"] = [
"cookiename=cookievalue;othercookiename=othercookievalue"
]
r = tutils.treq()
r.headers = h
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['cookievalue']
assert result['othercookiename'] == ['othercookievalue']
def test_get_cookies_withequalsign(self):
h = odict.ODictCaseless()
h["Cookie"] = [
"cookiename=coo=kievalue;othercookiename=othercookievalue"
]
r = tutils.treq()
r.headers = h
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['coo=kievalue']
assert result['othercookiename'] == ['othercookievalue']
def test_set_cookies(self):
h = odict.ODictCaseless()
h["Cookie"] = ["cookiename=cookievalue"]
r = tutils.treq()
r.headers = h
result = r.get_cookies()
result["cookiename"] = ["foo"]
r.set_cookies(result)
assert r.get_cookies()["cookiename"] == ["foo"]
def test_set_url(self):
r = tutils.treq_absolute()
r.url = "https://otheraddress:42/ORLY"
assert r.scheme == "https"
assert r.host == "otheraddress"
assert r.port == 42
assert r.path == "/ORLY"
try:
r.url = "//localhost:80/foo@bar"
assert False
except:
assert True
class TestRequest:
# def test_asterisk_form_in(self):
# f = tutils.tflow(req=None)
# protocol = mock_protocol("OPTIONS * HTTP/1.1")
@ -92,105 +349,35 @@ class TestRequest:
# "Host: address\r\n"
# "Content-Length: 0\r\n\r\n")
def test_set_url(self):
r = tutils.treq_absolute()
r.url = "https://otheraddress:42/ORLY"
assert r.scheme == "https"
assert r.host == "otheraddress"
assert r.port == 42
assert r.path == "/ORLY"
def test_repr(self):
r = tutils.treq()
assert repr(r)
def test_pretty_host(self):
r = tutils.treq()
assert r.pretty_host(True) == "address"
assert r.pretty_host(False) == "address"
r.headers["host"] = ["other"]
assert r.pretty_host(True) == "other"
assert r.pretty_host(False) == "address"
r.host = None
assert r.pretty_host(True) == "other"
assert r.pretty_host(False) is None
del r.headers["host"]
assert r.pretty_host(True) is None
assert r.pretty_host(False) is None
# Invalid IDNA
r.headers["host"] = [".disqus.com"]
assert r.pretty_host(True) == ".disqus.com"
def test_get_form_for_urlencoded(self):
r = tutils.treq()
r.headers.add("content-type", "application/x-www-form-urlencoded")
r.get_form_urlencoded = MagicMock()
r.get_form()
assert r.get_form_urlencoded.called
def test_get_form_for_multipart(self):
r = tutils.treq()
r.headers.add("content-type", "multipart/form-data")
r.get_form_multipart = MagicMock()
r.get_form()
assert r.get_form_multipart.called
def test_get_cookies_none(self):
h = odict.ODictCaseless()
r = tutils.treq()
r.headers = h
assert len(r.get_cookies()) == 0
def test_get_cookies_single(self):
h = odict.ODictCaseless()
h["Cookie"] = ["cookiename=cookievalue"]
r = tutils.treq()
r.headers = h
result = r.get_cookies()
assert len(result) == 1
assert result['cookiename'] == ['cookievalue']
def test_get_cookies_double(self):
h = odict.ODictCaseless()
h["Cookie"] = [
"cookiename=cookievalue;othercookiename=othercookievalue"
]
r = tutils.treq()
r.headers = h
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['cookievalue']
assert result['othercookiename'] == ['othercookievalue']
def test_get_cookies_withequalsign(self):
h = odict.ODictCaseless()
h["Cookie"] = [
"cookiename=coo=kievalue;othercookiename=othercookievalue"
]
r = tutils.treq()
r.headers = h
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['coo=kievalue']
assert result['othercookiename'] == ['othercookievalue']
def test_set_cookies(self):
h = odict.ODictCaseless()
h["Cookie"] = ["cookiename=cookievalue"]
r = tutils.treq()
r.headers = h
result = r.get_cookies()
result["cookiename"] = ["foo"]
r.set_cookies(result)
assert r.get_cookies()["cookiename"] == ["foo"]
class TestEmptyRequest(object):
def test_init(self):
req = semantics.EmptyRequest()
assert req
class TestResponse(object):
def test_headers_odict(self):
tutils.raises(AssertionError, semantics.Response,
(1, 1),
200,
headers='foobar',
)
resp = semantics.Response(
(1, 1),
200,
)
assert isinstance(resp.headers, odict.ODictCaseless)
def test_equal(self):
a = tutils.tresp()
b = tutils.tresp()
assert a == b
assert not a == 'foo'
assert not b == 'foo'
assert not 'foo' == a
assert not 'foo' == b
def test_repr(self):
r = tutils.tresp()
assert "unknown content type" in repr(r)

View File

@ -1,5 +1,6 @@
from netlib import encoding
def test_identity():
assert "string" == encoding.decode("identity", "string")
assert "string" == encoding.encode("identity", "string")

View File

@ -44,7 +44,11 @@ def test_client_greeting_assert_socks5():
assert False
raw = tutils.treader("XX")
tutils.raises(socks.SocksError, socks.ClientGreeting.from_file, raw, fail_early=True)
tutils.raises(
socks.SocksError,
socks.ClientGreeting.from_file,
raw,
fail_early=True)
def test_server_greeting():

View File

@ -1,5 +1,3 @@
import urlparse
from netlib import utils, odict, tutils
@ -30,8 +28,6 @@ def test_pretty_size():
assert utils.pretty_size(1024 * 1024) == "1MB"
def test_parse_url():
assert not utils.parse_url("")
@ -86,7 +82,6 @@ def test_urlencode():
assert utils.urlencode([('foo', 'bar')])
def test_urldecode():
s = "one=two&three=four"
assert len(utils.urldecode(s)) == 2
@ -101,3 +96,31 @@ def test_get_header_tokens():
assert utils.get_header_tokens(h, "foo") == ["bar", "voing"]
h["foo"] = ["bar, voing", "oink"]
assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"]
def test_multipartdecode():
boundary = 'somefancyboundary'
headers = odict.ODict(
[('content-type', ('multipart/form-data; boundary=%s' % boundary))])
content = "--{0}\n" \
"Content-Disposition: form-data; name=\"field1\"\n\n" \
"value1\n" \
"--{0}\n" \
"Content-Disposition: form-data; name=\"field2\"\n\n" \
"value2\n" \
"--{0}--".format(boundary)
form = utils.multipartdecode(headers, content)
assert len(form) == 2
assert form[0] == ('field1', 'value1')
assert form[1] == ('field2', 'value2')
def test_parse_content_type():
p = utils.parse_content_type
assert p("text/html") == ("text", "html", {})
assert p("text") is None
v = p("text/html; charset=UTF-8")
assert v == ('text', 'html', {'charset': 'UTF-8'})

View File

@ -3,7 +3,8 @@ import threading
import Queue
import cStringIO
import OpenSSL
from netlib import tcp, certutils, tutils
from netlib import tcp
from netlib import tutils
class ServerThread(threading.Thread):

View File

@ -2,7 +2,10 @@ import os
from nose.tools import raises
from netlib import tcp, http, websockets, tutils
from netlib import tcp
from netlib import tutils
from netlib import websockets
from netlib.http import status_codes
from netlib.http.exceptions import *
from netlib.http.http1 import HTTP1Protocol
from .. import tservers
@ -38,7 +41,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
req = http1_protocol.read_request()
key = self.protocol.check_client_handshake(req.headers)
preamble = http1_protocol.response_preamble(101)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble + "\r\n")
headers = self.protocol.server_handshake_headers(key)
self.wfile.write(headers.format() + "\r\n")
@ -62,7 +65,7 @@ class WebSocketsClient(tcp.TCPClient):
http1_protocol = HTTP1Protocol(self)
preamble = http1_protocol.request_preamble("GET", "/")
preamble = 'GET / HTTP/1.1'
self.wfile.write(preamble + "\r\n")
headers = self.protocol.client_handshake_headers()
self.client_nonce = headers.get_first("sec-websocket-key")
@ -162,7 +165,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler):
client_hs = http1_protocol.read_request()
self.protocol.check_client_handshake(client_hs.headers)
preamble = http1_protocol.response_preamble(101)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble + "\r\n")
headers = self.protocol.server_handshake_headers("malformed key")
self.wfile.write(headers.format() + "\r\n")