diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 9a227010d..29b9eb3c2 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -2,7 +2,6 @@ from __future__ import (absolute_import, print_function, division) from argparse import Action, ArgumentTypeError import binascii -from .. import http def parse_http_basic_auth(s): words = s.split() @@ -37,7 +36,6 @@ class NullProxyAuth(object): """ Clean up authentication headers, so they're not passed upstream. """ - pass def authenticate(self, headers_): """ diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 7cd26c122..987a7908d 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -1,6 +1,8 @@ from netlib import odict + class HttpError(Exception): + def __init__(self, code, message): super(HttpError, self).__init__(message) self.code = code @@ -11,6 +13,7 @@ class HttpErrorConnClosed(HttpError): class HttpAuthenticationError(Exception): + def __init__(self, auth_headers=None): super(HttpAuthenticationError, self).__init__( "Proxy Authentication Required" diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 2e85a762e..8eeb77449 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -1,28 +1,31 @@ from __future__ import (absolute_import, print_function, division) -import binascii -import collections import string import sys -import urlparse import time from netlib import odict, utils, tcp, http from netlib.http import semantics -from .. import status_codes from ..exceptions import * + class TCPHandler(object): + def __init__(self, rfile, wfile=None): self.rfile = rfile self.wfile = wfile + class HTTP1Protocol(semantics.ProtocolMixin): def __init__(self, tcp_handler=None, rfile=None, wfile=None): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - - def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + def read_request( + self, + include_body=True, + body_size_limit=None, + allow_empty=False, + ): """ Parse an HTTP request from a file stream @@ -129,8 +132,12 @@ class HTTP1Protocol(semantics.ProtocolMixin): timestamp_end, ) - - def read_response(self, request_method, body_size_limit, include_body=True): + def read_response( + self, + request_method, + body_size_limit, + include_body=True, + ): """ Returns an http.Response @@ -175,7 +182,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): # read separately body = None - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): # more accurate timestamp_start timestamp_start = self.tcp_handler.rfile.first_byte_timestamp @@ -195,7 +201,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): timestamp_end=timestamp_end, ) - def assemble_request(self, request): assert isinstance(request, semantics.Request) @@ -208,7 +213,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): headers = self._assemble_request_headers(request) return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) - def assemble_response(self, response): assert isinstance(response, semantics.Response) @@ -221,7 +225,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): headers = self._assemble_response_headers(response) return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) - def read_headers(self): """ Read a set of headers. @@ -266,7 +269,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): response_code, is_request, max_chunk_size=None - ): + ): """ Read an HTTP message body: headers: An ODictCaseless object @@ -321,9 +324,14 @@ class HTTP1Protocol(semantics.ProtocolMixin): "HTTP Body too large. Limit is %s," % limit ) - @classmethod - def expected_http_body_size(self, headers, is_request, request_method, response_code): + def expected_http_body_size( + self, + headers, + is_request, + request_method, + response_code, + ): """ Returns the expected body length: - a positive integer, if the size is known in advance @@ -359,20 +367,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return -1 - @classmethod - def request_preamble(self, method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - - @classmethod - def response_preamble(self, code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) - - @classmethod def has_chunked_encoding(self, headers): return "chunked" in [ @@ -390,7 +384,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): line = self.tcp_handler.rfile.readline() return line - def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -427,7 +420,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): if length == 0: return - @classmethod def _parse_http_protocol(self, line): """ @@ -447,7 +439,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return major, minor - @classmethod def _parse_init(self, line): try: @@ -461,7 +452,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return method, url, httpversion - @classmethod def _parse_init_connect(self, line): """ @@ -489,7 +479,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return host, port, httpversion - @classmethod def _parse_init_proxy(self, line): v = self._parse_init(line) @@ -503,7 +492,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): scheme, host, port, path = parts return method, scheme, host, port, path, httpversion - @classmethod def _parse_init_http(self, line): """ @@ -519,7 +507,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return method, url, httpversion - @classmethod def connection_close(self, httpversion, headers): """ @@ -539,7 +526,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): # be persistent return httpversion != (1, 1) - @classmethod def parse_response_line(self, line): parts = line.strip().split(" ", 2) @@ -554,7 +540,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None return (proto, code, msg) - @classmethod def _assemble_request_first_line(self, request): return request.legacy_first_line() @@ -575,7 +560,6 @@ class HTTP1Protocol(semantics.ProtocolMixin): return headers.format() - def _assemble_response_first_line(self, response): return 'HTTP/%s.%s %s %s' % ( response.httpversion[0], @@ -584,7 +568,11 @@ class HTTP1Protocol(semantics.ProtocolMixin): response.msg, ) - def _assemble_response_headers(self, response, preserve_transfer_encoding=False): + def _assemble_response_headers( + self, + response, + preserve_transfer_encoding=False, + ): headers = response.headers.copy() for k in response._headers_to_strip_off: del headers[k] diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index f7e604710..aa1fbae46 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -117,7 +117,7 @@ class Frame(object): return "\n".join([ "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), self.payload_human_readable(), "===============================================================", ]) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index a1ca4a182..c2ad5edde 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -9,6 +9,7 @@ from . import frame class TCPHandler(object): + def __init__(self, rfile, wfile=None): self.rfile = rfile self.wfile = wfile @@ -39,7 +40,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): ALPN_PROTO_H2 = 'h2' - def __init__( self, tcp_handler=None, @@ -60,7 +60,12 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.current_stream_id = None self.connection_preface_performed = False - def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + def read_request( + self, + include_body=True, + body_size_limit=None, + allow_empty=False, + ): self.perform_connection_preface() timestamp_start = time.time() @@ -92,7 +97,12 @@ class HTTP2Protocol(semantics.ProtocolMixin): return request - def read_response(self, request_method='', body_size_limit=None, include_body=True): + def read_response( + self, + request_method='', + body_size_limit=None, + include_body=True, + ): self.perform_connection_preface() timestamp_start = time.time() @@ -123,7 +133,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): return response - def assemble_request(self, request): assert isinstance(request, semantics.Request) @@ -133,13 +142,13 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = request.headers.copy() - if not ':authority' in headers.keys(): + if ':authority' not in headers.keys(): headers.add(':authority', bytes(authority), prepend=True) - if not ':scheme' in headers.keys(): + if ':scheme' not in headers.keys(): headers.add(':scheme', bytes(request.scheme), prepend=True) - if not ':path' in headers.keys(): + if ':path' not in headers.keys(): headers.add(':path', bytes(request.path), prepend=True) - if not ':method' in headers.keys(): + if ':method' not in headers.keys(): headers.add(':method', bytes(request.method), prepend=True) headers = headers.items() @@ -158,7 +167,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = response.headers.copy() - if not ':status' in headers.keys(): + if ':status' not in headers.keys(): headers.add(':status', bytes(str(response.status_code)), prepend=True) headers = headers.items() diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e7ae2b5f8..76213cd1a 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -1,13 +1,9 @@ from __future__ import (absolute_import, print_function, division) -import binascii -import collections -import string -import sys import urllib import urlparse from .. import utils, odict -from . import cookies +from . import cookies, exceptions from netlib import utils, encoding HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" @@ -18,11 +14,11 @@ CONTENT_MISSING = 0 class ProtocolMixin(object): - def read_request(self): - raise NotImplemented + def read_request(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError - def read_response(self): - raise NotImplemented + def read_response(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError def assemble(self, message): if isinstance(message, Request): @@ -32,14 +28,23 @@ class ProtocolMixin(object): else: raise ValueError("HTTP message not supported.") - def assemble_request(self, request): - raise NotImplemented + def assemble_request(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError - def assemble_response(self, response): - raise NotImplemented + def assemble_response(self, *args, **kwargs): # pragma: no cover + raise NotImplementedError class Request(object): + # This list is adopted legacy code. + # We probably don't need to strip off keep-alive. + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Keep-Alive', + 'Connection', + 'Transfer-Encoding', + 'Upgrade', + ] def __init__( self, @@ -71,7 +76,6 @@ class Request(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -114,7 +118,7 @@ class Request(object): self.httpversion[1], ) else: - raise http.HttpError(400, "Invalid request form") + raise exceptions.HttpError(400, "Invalid request form") def anticache(self): """ @@ -143,7 +147,7 @@ class Request(object): if self.headers["accept-encoding"]: self.headers["accept-encoding"] = [ ', '.join( - e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])] + e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))] def update_host_header(self): """ @@ -317,17 +321,18 @@ class Request(object): self.scheme, self.host, self.port, self.path = parts @property - def content(self): + def content(self): # pragma: no cover # TODO: remove deprecated getter return self.body @content.setter - def content(self, content): + def content(self, content): # pragma: no cover # TODO: remove deprecated setter self.body = content class EmptyRequest(Request): + def __init__(self): super(EmptyRequest, self).__init__( form_in="", @@ -339,10 +344,15 @@ class EmptyRequest(Request): httpversion=(0, 0), headers=odict.ODictCaseless(), body="", - ) + ) class Response(object): + _headers_to_strip_off = [ + 'Proxy-Connection', + 'Alternate-Protocol', + 'Alt-Svc', + ] def __init__( self, @@ -368,7 +378,6 @@ class Response(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -388,11 +397,9 @@ class Response(object): status_code=self.status_code, msg=self.msg, contenttype=self.headers.get_first( - "content-type", "unknown content type" - ), - size=size - ) - + "content-type", + "unknown content type"), + size=size) def get_cookies(self): """ @@ -430,21 +437,21 @@ class Response(object): self.headers["Set-Cookie"] = values @property - def content(self): + def content(self): # pragma: no cover # TODO: remove deprecated getter return self.body @content.setter - def content(self, content): + def content(self, content): # pragma: no cover # TODO: remove deprecated setter self.body = content @property - def code(self): + def code(self): # pragma: no cover # TODO: remove deprecated getter return self.status_code @code.setter - def code(self, code): + def code(self, code): # pragma: no cover # TODO: remove deprecated setter self.status_code = code diff --git a/netlib/odict.py b/netlib/odict.py index d02de08d1..11d5d52aa 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -91,8 +91,9 @@ class ODict(object): self.lst = self._filter_lst(k, self.lst) def __contains__(self, k): + k = self._kconv(k) for i in self.lst: - if self._kconv(i[0]) == self._kconv(k): + if self._kconv(i[0]) == k: return True return False diff --git a/netlib/tutils.py b/netlib/tutils.py index 5018b9e81..7434c1080 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -69,8 +69,6 @@ def raises(exc, obj, *args, **kwargs): test_data = utils.Data(__name__) - - def treq(content="content", scheme="http", host="address", port=22): """ @return: libmproxy.protocol.http.HTTPRequest @@ -119,7 +117,7 @@ def tresp(content="message"): "OK", headers, content, - time.time(), - time.time(), + timestamp_start=time.time(), + timestamp_end=time.time(), ) return resp diff --git a/netlib/utils.py b/netlib/utils.py index 35ea0ec7a..31dcd622f 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -4,6 +4,7 @@ import cgi import urllib import urlparse import string +import re def isascii(s): @@ -118,6 +119,7 @@ def pretty_size(size): class Data(object): + def __init__(self, name): m = __import__(name) dirname, _ = os.path.split(m.__file__) @@ -136,8 +138,6 @@ class Data(object): return fullpath - - def is_valid_port(port): if not 0 <= port <= 65535: return False @@ -220,6 +220,7 @@ def hostport(scheme, host, port): else: return "%s:%s" % (host, port) + def unparse_url(scheme, host, port, path=""): """ Returns a URL string, constructed from the specified compnents. @@ -234,8 +235,64 @@ def urlencode(s): s = [tuple(i) for i in s] return urllib.urlencode(s, False) + def urldecode(s): """ Takes a urlencoded string and returns a list of (key, value) tuples. """ return cgi.parse_qsl(s, keep_blank_values=True) + + +def parse_content_type(c): + """ + A simple parser for content-type values. Returns a (type, subtype, + parameters) tuple, where type and subtype are strings, and parameters + is a dict. If the string could not be parsed, return None. + + E.g. the following string: + + text/html; charset=UTF-8 + + Returns: + + ("text", "html", {"charset": "UTF-8"}) + """ + parts = c.split(";", 1) + ts = parts[0].split("/", 1) + if len(ts) != 2: + return None + d = {} + if len(parts) == 2: + for i in parts[1].split(";"): + clause = i.split("=", 1) + if len(clause) == 2: + d[clause[0].strip()] = clause[1].strip() + return ts[0].lower(), ts[1].lower(), d + + +def multipartdecode(hdrs, content): + """ + Takes a multipart boundary encoded string and returns list of (key, value) tuples. + """ + v = hdrs.get_first("content-type") + if v: + v = parse_content_type(v) + if not v: + return [] + boundary = v[2].get("boundary") + if not boundary: + return [] + + rx = re.compile(r'\bname="([^"]+)"') + r = [] + + for i in content.split("--" + boundary): + parts = i.splitlines() + if len(parts) > 1 and parts[0][0:2] != "--": + match = rx.search(parts[1]) + if match: + key = match.group(1) + value = "".join(parts[3 + parts[2:].index(""):]) + r.append((key, value)) + return r + return [] diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 49d8ee10a..1c4a03b2e 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -1,12 +1,11 @@ from __future__ import absolute_import -import base64 -import hashlib import os import struct import io from .protocol import Masker -from netlib import utils, odict, tcp +from netlib import tcp +from netlib import utils DEFAULT = object() @@ -22,6 +21,7 @@ OPCODE = utils.BiDi( PONG=0x0a ) + class FrameHeader(object): def __init__( diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 29b4db3db..6ce32eac5 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -2,10 +2,9 @@ from __future__ import absolute_import import base64 import hashlib import os -import struct -import io -from netlib import utils, odict, tcp +from netlib import odict +from netlib import utils # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -26,6 +25,7 @@ HEADER_WEBSOCKET_KEY = 'sec-websocket-key' HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept' HEADER_WEBSOCKET_VERSION = 'sec-websocket-version' + class Masker(object): """ @@ -53,6 +53,7 @@ class Masker(object): self.offset += len(ret) return ret + class WebsocketsProtocol(object): def __init__(self): diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index e3c3ff433..af77c55f8 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -1,18 +1,41 @@ import cStringIO import textwrap -import binascii from netlib import http, odict, tcp, tutils +from netlib.http import semantics from netlib.http.http1 import HTTP1Protocol from ... import tservers -def mock_protocol(data='', chunked=False): +class NoContentLengthHTTPHandler(tcp.BaseHandler): + def handle(self): + self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") + self.wfile.flush() + + +def mock_protocol(data=''): rfile = cStringIO.StringIO(data) wfile = cStringIO.StringIO() return HTTP1Protocol(rfile=rfile, wfile=wfile) +def match_http_string(data): + return textwrap.dedent(data).strip().replace('\n', '\r\n') + + +def test_stripped_chunked_encoding_no_content(): + """ + https://github.com/mitmproxy/mitmproxy/issues/186 + """ + + r = tutils.treq(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in mock_protocol()._assemble_request_headers(r) + + r = tutils.tresp(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in mock_protocol()._assemble_response_headers(r) + def test_has_chunked_encoding(): h = odict.ODictCaseless() @@ -75,7 +98,6 @@ def test_connection_close(): assert HTTP1Protocol.connection_close((1, 1), h) - def test_read_http_body_request(): h = odict.ODictCaseless() data = "testing" @@ -85,7 +107,7 @@ def test_read_http_body_request(): def test_read_http_body_response(): h = odict.ODictCaseless() data = "testing" - assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing" def test_read_http_body(): @@ -129,13 +151,13 @@ def test_read_http_body(): # test no content length: limit > actual content h = odict.ODictCaseless() data = "testing" - assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7 + assert len(mock_protocol(data).read_http_body(h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content data = "testing" tutils.raises( http.HttpError, - mock_protocol(data, chunked=True).read_http_body, + mock_protocol(data).read_http_body, h, 4, "GET", 200, False ) @@ -143,7 +165,7 @@ def test_read_http_body(): h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] data = "5\r\naaaaa\r\n0\r\n\r\n" - assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa" + assert mock_protocol(data).read_http_body(h, 100, "GET", 200, False) == "aaaaa" def test_expected_http_body_size(): @@ -167,6 +189,13 @@ def test_expected_http_body_size(): assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0 +def test_get_request_line(): + data = "\nfoo" + p = mock_protocol(data) + assert p._get_request_line() == "foo" + assert not p._get_request_line() + + def test_parse_http_protocol(): assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1) assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0) @@ -269,96 +298,7 @@ class TestReadHeaders: assert self._read(data) is None -class NoContentLengthHTTPHandler(tcp.BaseHandler): - - def handle(self): - self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") - self.wfile.flush() - - -class TestReadResponseNoContentLength(tservers.ServerTestBase): - handler = NoContentLengthHTTPHandler - - def test_no_content_length(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - resp = HTTP1Protocol(c).read_response("GET", None) - assert resp.body == "bar\r\n\r\n" - - -def test_read_response(): - def tst(data, method, body_size_limit, include_body=True): - data = textwrap.dedent(data) - return mock_protocol(data).read_response( - method, body_size_limit, include_body=include_body - ) - - tutils.raises("server disconnect", tst, "", "GET", None) - tutils.raises("invalid server response", tst, "foo", "GET", None) - data = """ - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http.Response( - (1, 1), 200, 'OK', odict.ODictCaseless(), '' - ) - data = """ - HTTP/1.1 200 - """ - assert tst(data, "GET", None) == http.Response( - (1, 1), 200, '', odict.ODictCaseless(), '' - ) - data = """ - HTTP/x 200 OK - """ - tutils.raises("invalid http version", tst, data, "GET", None) - data = """ - HTTP/1.1 xx OK - """ - tutils.raises("invalid server response", tst, data, "GET", None) - - data = """ - HTTP/1.1 100 CONTINUE - - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http.Response( - (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' - ) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None).body == 'foo' - assert tst(data, "HEAD", None).body == '' - - data = """ - HTTP/1.1 200 OK - \tContent-Length: 3 - - foo - """ - tutils.raises("invalid headers", tst, data, "GET", None) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None, include_body=False).body is None - - -def test_get_request_line(): - data = "\nfoo" - p = mock_protocol(data) - assert p._get_request_line() == "foo" - assert not p._get_request_line() - - -class TestReadRequest(): +class TestReadRequest(object): def tst(self, data, **kwargs): return mock_protocol(data).read_request(**kwargs) @@ -385,6 +325,10 @@ class TestReadRequest(): "\r\n" ) + def test_empty(self): + v = self.tst("", allow_empty=True) + assert isinstance(v, semantics.EmptyRequest) + def test_asterisk_form_in(self): v = self.tst("OPTIONS * HTTP/1.1") assert v.form_in == "relative" @@ -427,3 +371,131 @@ class TestReadRequest(): assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" assert v.body == "foo" assert p.tcp_handler.rfile.read(3) == "bar" + + +class TestReadResponse(object): + def tst(self, data, method, body_size_limit, include_body=True): + data = textwrap.dedent(data) + return mock_protocol(data).read_response( + method, body_size_limit, include_body=include_body + ) + + def test_errors(self): + tutils.raises("server disconnect", self.tst, "", "GET", None) + tutils.raises("invalid server response", self.tst, "foo", "GET", None) + + def test_simple(self): + data = """ + HTTP/1.1 200 + """ + assert self.tst(data, "GET", None) == http.Response( + (1, 1), 200, '', odict.ODictCaseless(), '' + ) + + def test_simple_message(self): + data = """ + HTTP/1.1 200 OK + """ + assert self.tst(data, "GET", None) == http.Response( + (1, 1), 200, 'OK', odict.ODictCaseless(), '' + ) + + def test_invalid_http_version(self): + data = """ + HTTP/x 200 OK + """ + tutils.raises("invalid http version", self.tst, data, "GET", None) + + def test_invalid_status_code(self): + data = """ + HTTP/1.1 xx OK + """ + tutils.raises("invalid server response", self.tst, data, "GET", None) + + def test_valid_with_continue(self): + data = """ + HTTP/1.1 100 CONTINUE + + HTTP/1.1 200 OK + """ + assert self.tst(data, "GET", None) == http.Response( + (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' + ) + + def test_simple_body(self): + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert self.tst(data, "GET", None).body == 'foo' + assert self.tst(data, "HEAD", None).body == '' + + def test_invalid_headers(self): + data = """ + HTTP/1.1 200 OK + \tContent-Length: 3 + + foo + """ + tutils.raises("invalid headers", self.tst, data, "GET", None) + + def test_without_body(self): + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert self.tst(data, "GET", None, include_body=False).body is None + + +class TestReadResponseNoContentLength(tservers.ServerTestBase): + handler = NoContentLengthHTTPHandler + + def test_no_content_length(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + resp = HTTP1Protocol(c).read_response("GET", None) + assert resp.body == "bar\r\n\r\n" + + +class TestAssembleRequest(object): + def test_simple(self): + req = tutils.treq() + b = HTTP1Protocol().assemble_request(req) + assert b == match_http_string(""" + GET /path HTTP/1.1 + header: qvalue + Host: address:22 + Content-Length: 7 + + content""") + + def test_body_missing(self): + req = tutils.treq(content=semantics.CONTENT_MISSING) + tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req) + + def test_not_a_request(self): + tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo') + + +class TestAssembleResponse(object): + def test_simple(self): + resp = tutils.tresp() + b = HTTP1Protocol().assemble_response(resp) + print(b) + assert b == match_http_string(""" + HTTP/1.1 200 OK + header_response: svalue + Content-Length: 7 + + message""") + + def test_body_missing(self): + resp = tutils.tresp(content=semantics.CONTENT_MISSING) + tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp) + + def test_not_a_request(self): + tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo') diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 8a27bbb1b..3044179f0 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,10 +1,25 @@ import OpenSSL +import mock from netlib import tcp, odict, http, tutils from netlib.http import http2 +from netlib.http.http2 import HTTP2Protocol from netlib.http.http2.frame import * from ... import tservers +class TestTCPHandlerWrapper: + def test_wrapped(self): + h = http2.TCPHandler(rfile='foo', wfile='bar') + p = HTTP2Protocol(h) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + + def test_direct(self): + p = HTTP2Protocol(rfile='foo', wfile='bar') + assert isinstance(p.tcp_handler, http2.TCPHandler) + assert p.tcp_handler.rfile == 'foo' + assert p.tcp_handler.wfile == 'bar' + class EchoHandler(tcp.BaseHandler): sni = None @@ -16,10 +31,40 @@ class EchoHandler(tcp.BaseHandler): self.wfile.flush() +class TestProtocol: + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=False) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert mock_client_method.called + assert not mock_server_method.called + + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): + protocol = HTTP2Protocol(is_server=True) + protocol.connection_preface_performed = True + + protocol.perform_connection_preface() + assert not mock_client_method.called + assert not mock_server_method.called + + protocol.perform_connection_preface(force=True) + assert not mock_client_method.called + assert mock_server_method.called + + class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, + alpn_select=HTTP2Protocol.ALPN_PROTO_H2, ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -27,8 +72,8 @@ class TestCheckALPNMatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) + c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + protocol = HTTP2Protocol(c) assert protocol.check_alpn() @@ -43,8 +88,8 @@ class TestCheckALPNMismatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) + c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + protocol = HTTP2Protocol(c) tutils.raises(NotImplementedError, protocol.check_alpn) @@ -76,8 +121,13 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase): def test_perform_server_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed protocol.perform_server_connection_preface() + assert protocol.connection_preface_performed + + tutils.raises(tcp.NetLibIncomplete, protocol.perform_server_connection_preface, force=True) class TestPerformClientConnectionPreface(tservers.ServerTestBase): @@ -107,13 +157,16 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase): def test_perform_client_connection_preface(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) + + assert not protocol.connection_preface_performed protocol.perform_client_connection_preface() + assert protocol.connection_preface_performed class TestClientStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) def test_client_stream_ids(self): assert self.protocol.current_stream_id is None @@ -127,7 +180,7 @@ class TestClientStreamIds(): class TestServerStreamIds(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c, is_server=True) + protocol = HTTP2Protocol(c, is_server=True) def test_server_stream_ids(self): assert self.protocol.current_stream_id is None @@ -154,7 +207,7 @@ class TestApplySettings(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) protocol._apply_settings({ SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', @@ -182,13 +235,13 @@ class TestCreateHeaders(): (b':scheme', b'https'), (b'foo', b'bar')] - bytes = http2.HTTP2Protocol(self.c)._create_headers( + bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=True) assert b''.join(bytes) ==\ '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ .decode('hex') - bytes = http2.HTTP2Protocol(self.c)._create_headers( + bytes = HTTP2Protocol(self.c)._create_headers( headers, 1, end_stream=False) assert b''.join(bytes) ==\ '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ @@ -199,7 +252,7 @@ class TestCreateHeaders(): class TestCreateBody(): c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) + protocol = HTTP2Protocol(c) def test_create_body_empty(self): bytes = self.protocol._create_body(b'', 1) @@ -215,98 +268,6 @@ class TestCreateBody(): # TODO: add test for too large frames -class TestAssembleRequest(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_assemble_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - None, - None, - )) - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - - def test_assemble_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request( - '', - 'GET', - 'https', - '', - '', - '/', - (2, 0), - odict.ODictCaseless([('foo', 'bar')]), - 'foobar', - )) - assert len(bytes) == 2 - assert bytes[0] ==\ - '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') - - -class TestReadResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801040000000188628594e78c767f'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - protocol.connection_preface_performed = True - - resp = protocol.read_response() - - assert resp.httpversion == (2, 0) - assert resp.status_code == 200 - assert resp.msg == "" - assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] - assert resp.body == b'foobar' - - -class TestReadEmptyResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801050000000188628594e78c767f'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_empty_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - protocol.connection_preface_performed = True - - resp = protocol.read_response() - - assert resp.stream_id - assert resp.httpversion == (2, 0) - assert resp.status_code == 200 - assert resp.msg == "" - assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] - assert resp.body == b'' - - class TestReadRequest(tservers.ServerTestBase): class handler(tcp.BaseHandler): @@ -323,7 +284,7 @@ class TestReadRequest(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c, is_server=True) + protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True resp = protocol.read_request() @@ -333,11 +294,138 @@ class TestReadRequest(tservers.ServerTestBase): assert resp.body == b'foobar' -class TestCreateResponse(): +class TestReadResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801040000000188628594e78c767f'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response() + + assert resp.httpversion == (2, 0) + assert resp.status_code == 200 + assert resp.msg == "" + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.body == b'foobar' + assert resp.timestamp_end + + def test_read_response_no_body(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response(include_body=False) + + assert resp.httpversion == (2, 0) + assert resp.status_code == 200 + assert resp.msg == "" + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.body == b'foobar' # TODO: this should be true: assert resp.body == http.CONTENT_MISSING + assert not resp.timestamp_end + + +class TestReadEmptyResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801050000000188628594e78c767f'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_empty_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = HTTP2Protocol(c) + protocol.connection_preface_performed = True + + resp = protocol.read_response() + + assert resp.stream_id + assert resp.httpversion == (2, 0) + assert resp.status_code == 200 + assert resp.msg == "" + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.body == b'' + + +class TestAssembleRequest(object): c = tcp.TCPClient(("127.0.0.1", 0)) - def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + def test_request_simple(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + )) + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + + def test_request_with_stream_id(self): + req = http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + None, + None, + ) + req.stream_id = 0x42 + bytes = HTTP2Protocol(self.c).assemble_request(req) + assert len(bytes) == 1 + print(bytes[0].encode('hex')) + assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex') + + def test_request_with_body(self): + bytes = HTTP2Protocol(self.c).assemble_request(http.Request( + '', + 'GET', + 'https', + '', + '', + '/', + (2, 0), + odict.ODictCaseless([('foo', 'bar')]), + 'foobar', + )) + assert len(bytes) == 2 + assert bytes[0] ==\ + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') + + +class TestAssembleResponse(object): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_simple(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( (2, 0), 200, )) @@ -345,8 +433,19 @@ class TestCreateResponse(): assert bytes[0] ==\ '00000101050000000288'.decode('hex') - def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( + def test_with_stream_id(self): + resp = http.Response( + (2, 0), + 200, + ) + resp.stream_id = 0x42 + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000004288'.decode('hex') + + def test_with_body(self): + bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( (2, 0), 200, '', diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py index aa57f8310..0131c7efa 100644 --- a/test/http/test_exceptions.py +++ b/test/http/test_exceptions.py @@ -1,6 +1,27 @@ from netlib.http.exceptions import * +from netlib import odict -def test_HttpAuthenticationError(): - x = HttpAuthenticationError({"foo": "bar"}) - assert str(x) - assert "foo" in x.headers +class TestHttpError: + def test_simple(self): + e = HttpError(404, "Not found") + assert str(e) + +class TestHttpAuthenticationError: + def test_init(self): + headers = odict.ODictCaseless([("foo", "bar")]) + x = HttpAuthenticationError(headers) + assert str(x) + assert isinstance(x.headers, odict.ODictCaseless) + assert x.code == 407 + assert x.headers == headers + print(x.headers.keys()) + assert "foo" in x.headers.keys() + + def test_header_conversion(self): + headers = {"foo": "bar"} + x = HttpAuthenticationError(headers) + assert isinstance(x.headers, odict.ODictCaseless) + assert x.headers.lst == headers.items() + + def test_repr(self): + assert repr(HttpAuthenticationError()) == "Proxy Authentication Required" diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py index d58a44d2e..7ef69dcff 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_semantics.py @@ -1,18 +1,275 @@ -import cStringIO -import textwrap -import binascii -from mock import MagicMock +import mock -from netlib import http, odict, tcp, tutils -from netlib.http import http1 +from netlib import http +from netlib import odict +from netlib import tutils +from netlib import utils +from netlib.http import semantics from netlib.http.semantics import CONTENT_MISSING -from .. import tservers -def test_httperror(): - e = http.exceptions.HttpError(404, "Not found") - assert str(e) +class TestProtocolMixin(object): + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") + def test_assemble_request(self, mock_request_method, mock_response_method): + p = semantics.ProtocolMixin() + p.assemble(tutils.treq()) + assert mock_request_method.called + assert not mock_response_method.called + + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") + @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") + def test_assemble_response(self, mock_request_method, mock_response_method): + p = semantics.ProtocolMixin() + p.assemble(tutils.tresp()) + assert not mock_request_method.called + assert mock_response_method.called + + def test_assemble_foo(self): + p = semantics.ProtocolMixin() + tutils.raises(ValueError, p.assemble, 'foo') + +class TestRequest(object): + def test_repr(self): + r = tutils.treq() + assert repr(r) + + def test_headers_odict(self): + tutils.raises(AssertionError, semantics.Request, + 'form_in', + 'method', + 'scheme', + 'host', + 'port', + 'path', + (1, 1), + 'foobar', + ) + + req = semantics.Request( + 'form_in', + 'method', + 'scheme', + 'host', + 'port', + 'path', + (1, 1), + ) + assert isinstance(req.headers, odict.ODictCaseless) + + def test_equal(self): + a = tutils.treq() + b = tutils.treq() + assert a == b + + assert not a == 'foo' + assert not b == 'foo' + assert not 'foo' == a + assert not 'foo' == b + + def test_legacy_first_line(self): + req = tutils.treq() + + req.form_in = 'relative' + assert req.legacy_first_line() == "GET /path HTTP/1.1" + + req.form_in = 'authority' + assert req.legacy_first_line() == "GET address:22 HTTP/1.1" + + req.form_in = 'absolute' + assert req.legacy_first_line() == "GET http://address:22/path HTTP/1.1" + + req.form_in = 'foobar' + tutils.raises(http.HttpError, req.legacy_first_line) + + def test_anticache(self): + req = tutils.treq() + req.headers.add("If-Modified-Since", "foo") + req.headers.add("If-None-Match", "bar") + req.anticache() + assert "If-Modified-Since" not in req.headers + assert "If-None-Match" not in req.headers + + def test_anticomp(self): + req = tutils.treq() + req.headers.add("Accept-Encoding", "foobar") + req.anticomp() + assert req.headers["Accept-Encoding"] == ["identity"] + + def test_constrain_encoding(self): + req = tutils.treq() + req.headers.add("Accept-Encoding", "identity, gzip, foo") + req.constrain_encoding() + assert "foo" not in req.headers.get_first("Accept-Encoding") + + def test_update_host(self): + req = tutils.treq() + req.headers.add("Host", "") + req.host = "foobar" + req.update_host_header() + assert req.headers.get_first("Host") == "foobar" + + def test_get_form(self): + req = tutils.treq() + assert req.get_form() == odict.ODict() + + @mock.patch("netlib.http.semantics.Request.get_form_multipart") + @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart): + req = tutils.treq() + assert req.get_form() == odict.ODict() + + req = tutils.treq() + req.body = "foobar" + req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED] + req.get_form() + assert req.get_form_urlencoded.called + assert not req.get_form_multipart.called + + @mock.patch("netlib.http.semantics.Request.get_form_multipart") + @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): + req = tutils.treq() + req.body = "foobar" + req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART] + req.get_form() + assert not req.get_form_urlencoded.called + assert req.get_form_multipart.called + + def test_get_form_urlencoded(self): + req = tutils.treq("foobar") + assert req.get_form_urlencoded() == odict.ODict() + + req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED] + assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body)) + + def test_get_form_multipart(self): + req = tutils.treq("foobar") + assert req.get_form_multipart() == odict.ODict() + + req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART] + assert req.get_form_multipart() == odict.ODict( + utils.multipartdecode( + req.headers, + req.body)) + + def test_set_form_urlencoded(self): + req = tutils.treq() + req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')])) + assert req.headers.get_first("Content-Type") == semantics.HDR_FORM_URLENCODED + assert req.body + + def test_get_path_components(self): + req = tutils.treq() + assert req.get_path_components() + # TODO: add meaningful assertions + + def test_set_path_components(self): + req = tutils.treq() + req.set_path_components(["foo", "bar"]) + # TODO: add meaningful assertions + + def test_get_query(self): + req = tutils.treq() + assert req.get_query().lst == [] + + req.url = "http://localhost:80/foo?bar=42" + assert req.get_query().lst == [("bar", "42")] + + def test_set_query(self): + req = tutils.treq() + req.set_query(odict.ODict([])) + + def test_pretty_host(self): + r = tutils.treq() + assert r.pretty_host(True) == "address" + assert r.pretty_host(False) == "address" + r.headers["host"] = ["other"] + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) == "address" + r.host = None + assert r.pretty_host(True) == "other" + assert r.pretty_host(False) is None + del r.headers["host"] + assert r.pretty_host(True) is None + assert r.pretty_host(False) is None + + # Invalid IDNA + r.headers["host"] = [".disqus.com"] + assert r.pretty_host(True) == ".disqus.com" + + def test_pretty_url(self): + req = tutils.treq() + req.form_out = "authority" + assert req.pretty_url(True) == "address:22" + assert req.pretty_url(False) == "address:22" + + req.form_out = "relative" + assert req.pretty_url(True) == "http://address:22/path" + assert req.pretty_url(False) == "http://address:22/path" + + def test_get_cookies_none(self): + h = odict.ODictCaseless() + r = tutils.treq() + r.headers = h + assert len(r.get_cookies()) == 0 + + def test_get_cookies_single(self): + h = odict.ODictCaseless() + h["Cookie"] = ["cookiename=cookievalue"] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 1 + assert result['cookiename'] == ['cookievalue'] + + def test_get_cookies_double(self): + h = odict.ODictCaseless() + h["Cookie"] = [ + "cookiename=cookievalue;othercookiename=othercookievalue" + ] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['cookievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_get_cookies_withequalsign(self): + h = odict.ODictCaseless() + h["Cookie"] = [ + "cookiename=coo=kievalue;othercookiename=othercookievalue" + ] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + assert len(result) == 2 + assert result['cookiename'] == ['coo=kievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_set_cookies(self): + h = odict.ODictCaseless() + h["Cookie"] = ["cookiename=cookievalue"] + r = tutils.treq() + r.headers = h + result = r.get_cookies() + result["cookiename"] = ["foo"] + r.set_cookies(result) + assert r.get_cookies()["cookiename"] == ["foo"] + + def test_set_url(self): + r = tutils.treq_absolute() + r.url = "https://otheraddress:42/ORLY" + assert r.scheme == "https" + assert r.host == "otheraddress" + assert r.port == 42 + assert r.path == "/ORLY" + + try: + r.url = "//localhost:80/foo@bar" + assert False + except: + assert True -class TestRequest: # def test_asterisk_form_in(self): # f = tutils.tflow(req=None) # protocol = mock_protocol("OPTIONS * HTTP/1.1") @@ -92,105 +349,35 @@ class TestRequest: # "Host: address\r\n" # "Content-Length: 0\r\n\r\n") - def test_set_url(self): - r = tutils.treq_absolute() - r.url = "https://otheraddress:42/ORLY" - assert r.scheme == "https" - assert r.host == "otheraddress" - assert r.port == 42 - assert r.path == "/ORLY" - - def test_repr(self): - r = tutils.treq() - assert repr(r) - - def test_pretty_host(self): - r = tutils.treq() - assert r.pretty_host(True) == "address" - assert r.pretty_host(False) == "address" - r.headers["host"] = ["other"] - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) == "address" - r.host = None - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) is None - del r.headers["host"] - assert r.pretty_host(True) is None - assert r.pretty_host(False) is None - - # Invalid IDNA - r.headers["host"] = [".disqus.com"] - assert r.pretty_host(True) == ".disqus.com" - - def test_get_form_for_urlencoded(self): - r = tutils.treq() - r.headers.add("content-type", "application/x-www-form-urlencoded") - r.get_form_urlencoded = MagicMock() - - r.get_form() - - assert r.get_form_urlencoded.called - - def test_get_form_for_multipart(self): - r = tutils.treq() - r.headers.add("content-type", "multipart/form-data") - r.get_form_multipart = MagicMock() - - r.get_form() - - assert r.get_form_multipart.called - - def test_get_cookies_none(self): - h = odict.ODictCaseless() - r = tutils.treq() - r.headers = h - assert len(r.get_cookies()) == 0 - - def test_get_cookies_single(self): - h = odict.ODictCaseless() - h["Cookie"] = ["cookiename=cookievalue"] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - assert len(result) == 1 - assert result['cookiename'] == ['cookievalue'] - - def test_get_cookies_double(self): - h = odict.ODictCaseless() - h["Cookie"] = [ - "cookiename=cookievalue;othercookiename=othercookievalue" - ] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['cookievalue'] - assert result['othercookiename'] == ['othercookievalue'] - - def test_get_cookies_withequalsign(self): - h = odict.ODictCaseless() - h["Cookie"] = [ - "cookiename=coo=kievalue;othercookiename=othercookievalue" - ] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['coo=kievalue'] - assert result['othercookiename'] == ['othercookievalue'] - - def test_set_cookies(self): - h = odict.ODictCaseless() - h["Cookie"] = ["cookiename=cookievalue"] - r = tutils.treq() - r.headers = h - result = r.get_cookies() - result["cookiename"] = ["foo"] - r.set_cookies(result) - assert r.get_cookies()["cookiename"] == ["foo"] - +class TestEmptyRequest(object): + def test_init(self): + req = semantics.EmptyRequest() + assert req class TestResponse(object): + def test_headers_odict(self): + tutils.raises(AssertionError, semantics.Response, + (1, 1), + 200, + headers='foobar', + ) + + resp = semantics.Response( + (1, 1), + 200, + ) + assert isinstance(resp.headers, odict.ODictCaseless) + + def test_equal(self): + a = tutils.tresp() + b = tutils.tresp() + assert a == b + + assert not a == 'foo' + assert not b == 'foo' + assert not 'foo' == a + assert not 'foo' == b + def test_repr(self): r = tutils.tresp() assert "unknown content type" in repr(r) diff --git a/test/test_encoding.py b/test/test_encoding.py index faf718ae6..612aea890 100644 --- a/test/test_encoding.py +++ b/test/test_encoding.py @@ -1,5 +1,6 @@ from netlib import encoding + def test_identity(): assert "string" == encoding.decode("identity", "string") assert "string" == encoding.encode("identity", "string") diff --git a/test/test_socks.py b/test/test_socks.py index 36fc5b3d4..3d109f428 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -44,7 +44,11 @@ def test_client_greeting_assert_socks5(): assert False raw = tutils.treader("XX") - tutils.raises(socks.SocksError, socks.ClientGreeting.from_file, raw, fail_early=True) + tutils.raises( + socks.SocksError, + socks.ClientGreeting.from_file, + raw, + fail_early=True) def test_server_greeting(): diff --git a/test/test_utils.py b/test/test_utils.py index 5e681eb62..fc7174d63 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,5 +1,3 @@ -import urlparse - from netlib import utils, odict, tutils @@ -30,8 +28,6 @@ def test_pretty_size(): assert utils.pretty_size(1024 * 1024) == "1MB" - - def test_parse_url(): assert not utils.parse_url("") @@ -86,7 +82,6 @@ def test_urlencode(): assert utils.urlencode([('foo', 'bar')]) - def test_urldecode(): s = "one=two&three=four" assert len(utils.urldecode(s)) == 2 @@ -101,3 +96,31 @@ def test_get_header_tokens(): assert utils.get_header_tokens(h, "foo") == ["bar", "voing"] h["foo"] = ["bar, voing", "oink"] assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + + +def test_multipartdecode(): + boundary = 'somefancyboundary' + headers = odict.ODict( + [('content-type', ('multipart/form-data; boundary=%s' % boundary))]) + content = "--{0}\n" \ + "Content-Disposition: form-data; name=\"field1\"\n\n" \ + "value1\n" \ + "--{0}\n" \ + "Content-Disposition: form-data; name=\"field2\"\n\n" \ + "value2\n" \ + "--{0}--".format(boundary) + + form = utils.multipartdecode(headers, content) + + assert len(form) == 2 + assert form[0] == ('field1', 'value1') + assert form[1] == ('field2', 'value2') + + +def test_parse_content_type(): + p = utils.parse_content_type + assert p("text/html") == ("text", "html", {}) + assert p("text") is None + + v = p("text/html; charset=UTF-8") + assert v == ('text', 'html', {'charset': 'UTF-8'}) diff --git a/test/tservers.py b/test/tservers.py index 3f3ea8b4b..682a9144e 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -3,7 +3,8 @@ import threading import Queue import cStringIO import OpenSSL -from netlib import tcp, certutils, tutils +from netlib import tcp +from netlib import tutils class ServerThread(threading.Thread): diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 28dbb8332..752f2c3ea 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -2,7 +2,10 @@ import os from nose.tools import raises -from netlib import tcp, http, websockets, tutils +from netlib import tcp +from netlib import tutils +from netlib import websockets +from netlib.http import status_codes from netlib.http.exceptions import * from netlib.http.http1 import HTTP1Protocol from .. import tservers @@ -38,7 +41,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): req = http1_protocol.read_request() key = self.protocol.check_client_handshake(req.headers) - preamble = http1_protocol.response_preamble(101) + preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers(key) self.wfile.write(headers.format() + "\r\n") @@ -62,7 +65,7 @@ class WebSocketsClient(tcp.TCPClient): http1_protocol = HTTP1Protocol(self) - preamble = http1_protocol.request_preamble("GET", "/") + preamble = 'GET / HTTP/1.1' self.wfile.write(preamble + "\r\n") headers = self.protocol.client_handshake_headers() self.client_nonce = headers.get_first("sec-websocket-key") @@ -162,7 +165,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler): client_hs = http1_protocol.read_request() self.protocol.check_client_handshake(client_hs.headers) - preamble = http1_protocol.response_preamble(101) + preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers("malformed key") self.wfile.write(headers.format() + "\r\n")