mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
Merge pull request #85 from Kriechi/http2-wip
add move tests and code from mitmproxy
This commit is contained in:
commit
f3a6113391
@ -2,7 +2,6 @@ from __future__ import (absolute_import, print_function, division)
|
||||
from argparse import Action, ArgumentTypeError
|
||||
import binascii
|
||||
|
||||
from .. import http
|
||||
|
||||
def parse_http_basic_auth(s):
|
||||
words = s.split()
|
||||
@ -37,7 +36,6 @@ class NullProxyAuth(object):
|
||||
"""
|
||||
Clean up authentication headers, so they're not passed upstream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def authenticate(self, headers_):
|
||||
"""
|
||||
|
@ -1,6 +1,8 @@
|
||||
from netlib import odict
|
||||
|
||||
|
||||
class HttpError(Exception):
|
||||
|
||||
def __init__(self, code, message):
|
||||
super(HttpError, self).__init__(message)
|
||||
self.code = code
|
||||
@ -11,6 +13,7 @@ class HttpErrorConnClosed(HttpError):
|
||||
|
||||
|
||||
class HttpAuthenticationError(Exception):
|
||||
|
||||
def __init__(self, auth_headers=None):
|
||||
super(HttpAuthenticationError, self).__init__(
|
||||
"Proxy Authentication Required"
|
||||
|
@ -1,28 +1,31 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
import binascii
|
||||
import collections
|
||||
import string
|
||||
import sys
|
||||
import urlparse
|
||||
import time
|
||||
|
||||
from netlib import odict, utils, tcp, http
|
||||
from netlib.http import semantics
|
||||
from .. import status_codes
|
||||
from ..exceptions import *
|
||||
|
||||
|
||||
class TCPHandler(object):
|
||||
|
||||
def __init__(self, rfile, wfile=None):
|
||||
self.rfile = rfile
|
||||
self.wfile = wfile
|
||||
|
||||
|
||||
class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
|
||||
def __init__(self, tcp_handler=None, rfile=None, wfile=None):
|
||||
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
|
||||
|
||||
|
||||
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
|
||||
def read_request(
|
||||
self,
|
||||
include_body=True,
|
||||
body_size_limit=None,
|
||||
allow_empty=False,
|
||||
):
|
||||
"""
|
||||
Parse an HTTP request from a file stream
|
||||
|
||||
@ -129,8 +132,12 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
timestamp_end,
|
||||
)
|
||||
|
||||
|
||||
def read_response(self, request_method, body_size_limit, include_body=True):
|
||||
def read_response(
|
||||
self,
|
||||
request_method,
|
||||
body_size_limit,
|
||||
include_body=True,
|
||||
):
|
||||
"""
|
||||
Returns an http.Response
|
||||
|
||||
@ -175,7 +182,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
# read separately
|
||||
body = None
|
||||
|
||||
|
||||
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
|
||||
@ -195,7 +201,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
timestamp_end=timestamp_end,
|
||||
)
|
||||
|
||||
|
||||
def assemble_request(self, request):
|
||||
assert isinstance(request, semantics.Request)
|
||||
|
||||
@ -208,7 +213,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
headers = self._assemble_request_headers(request)
|
||||
return "%s\r\n%s\r\n%s" % (first_line, headers, request.body)
|
||||
|
||||
|
||||
def assemble_response(self, response):
|
||||
assert isinstance(response, semantics.Response)
|
||||
|
||||
@ -221,7 +225,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
headers = self._assemble_response_headers(response)
|
||||
return "%s\r\n%s\r\n%s" % (first_line, headers, response.body)
|
||||
|
||||
|
||||
def read_headers(self):
|
||||
"""
|
||||
Read a set of headers.
|
||||
@ -266,7 +269,7 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
response_code,
|
||||
is_request,
|
||||
max_chunk_size=None
|
||||
):
|
||||
):
|
||||
"""
|
||||
Read an HTTP message body:
|
||||
headers: An ODictCaseless object
|
||||
@ -321,9 +324,14 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
"HTTP Body too large. Limit is %s," % limit
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def expected_http_body_size(self, headers, is_request, request_method, response_code):
|
||||
def expected_http_body_size(
|
||||
self,
|
||||
headers,
|
||||
is_request,
|
||||
request_method,
|
||||
response_code,
|
||||
):
|
||||
"""
|
||||
Returns the expected body length:
|
||||
- a positive integer, if the size is known in advance
|
||||
@ -359,20 +367,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
return -1
|
||||
|
||||
|
||||
@classmethod
|
||||
def request_preamble(self, method, resource, http_major="1", http_minor="1"):
|
||||
return '%s %s HTTP/%s.%s' % (
|
||||
method, resource, http_major, http_minor
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def response_preamble(self, code, message=None, http_major="1", http_minor="1"):
|
||||
if message is None:
|
||||
message = status_codes.RESPONSES.get(code)
|
||||
return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message)
|
||||
|
||||
|
||||
@classmethod
|
||||
def has_chunked_encoding(self, headers):
|
||||
return "chunked" in [
|
||||
@ -390,7 +384,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
line = self.tcp_handler.rfile.readline()
|
||||
return line
|
||||
|
||||
|
||||
def _read_chunked(self, limit, is_request):
|
||||
"""
|
||||
Read a chunked HTTP body.
|
||||
@ -427,7 +420,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
if length == 0:
|
||||
return
|
||||
|
||||
|
||||
@classmethod
|
||||
def _parse_http_protocol(self, line):
|
||||
"""
|
||||
@ -447,7 +439,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
return None
|
||||
return major, minor
|
||||
|
||||
|
||||
@classmethod
|
||||
def _parse_init(self, line):
|
||||
try:
|
||||
@ -461,7 +452,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
return None
|
||||
return method, url, httpversion
|
||||
|
||||
|
||||
@classmethod
|
||||
def _parse_init_connect(self, line):
|
||||
"""
|
||||
@ -489,7 +479,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
return None
|
||||
return host, port, httpversion
|
||||
|
||||
|
||||
@classmethod
|
||||
def _parse_init_proxy(self, line):
|
||||
v = self._parse_init(line)
|
||||
@ -503,7 +492,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
scheme, host, port, path = parts
|
||||
return method, scheme, host, port, path, httpversion
|
||||
|
||||
|
||||
@classmethod
|
||||
def _parse_init_http(self, line):
|
||||
"""
|
||||
@ -519,7 +507,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
return None
|
||||
return method, url, httpversion
|
||||
|
||||
|
||||
@classmethod
|
||||
def connection_close(self, httpversion, headers):
|
||||
"""
|
||||
@ -539,7 +526,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
# be persistent
|
||||
return httpversion != (1, 1)
|
||||
|
||||
|
||||
@classmethod
|
||||
def parse_response_line(self, line):
|
||||
parts = line.strip().split(" ", 2)
|
||||
@ -554,7 +540,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
return None
|
||||
return (proto, code, msg)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _assemble_request_first_line(self, request):
|
||||
return request.legacy_first_line()
|
||||
@ -575,7 +560,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
|
||||
return headers.format()
|
||||
|
||||
|
||||
def _assemble_response_first_line(self, response):
|
||||
return 'HTTP/%s.%s %s %s' % (
|
||||
response.httpversion[0],
|
||||
@ -584,7 +568,11 @@ class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
response.msg,
|
||||
)
|
||||
|
||||
def _assemble_response_headers(self, response, preserve_transfer_encoding=False):
|
||||
def _assemble_response_headers(
|
||||
self,
|
||||
response,
|
||||
preserve_transfer_encoding=False,
|
||||
):
|
||||
headers = response.headers.copy()
|
||||
for k in response._headers_to_strip_off:
|
||||
del headers[k]
|
||||
|
@ -117,7 +117,7 @@ class Frame(object):
|
||||
|
||||
return "\n".join([
|
||||
"%s: %s | length: %d | flags: %#x | stream_id: %d" % (
|
||||
direction, self.__class__.__name__, self.length, self.flags, self.stream_id),
|
||||
direction, self.__class__.__name__, self.length, self.flags, self.stream_id),
|
||||
self.payload_human_readable(),
|
||||
"===============================================================",
|
||||
])
|
||||
|
@ -9,6 +9,7 @@ from . import frame
|
||||
|
||||
|
||||
class TCPHandler(object):
|
||||
|
||||
def __init__(self, rfile, wfile=None):
|
||||
self.rfile = rfile
|
||||
self.wfile = wfile
|
||||
@ -39,7 +40,6 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
ALPN_PROTO_H2 = 'h2'
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tcp_handler=None,
|
||||
@ -60,7 +60,12 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
self.current_stream_id = None
|
||||
self.connection_preface_performed = False
|
||||
|
||||
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
|
||||
def read_request(
|
||||
self,
|
||||
include_body=True,
|
||||
body_size_limit=None,
|
||||
allow_empty=False,
|
||||
):
|
||||
self.perform_connection_preface()
|
||||
|
||||
timestamp_start = time.time()
|
||||
@ -92,7 +97,12 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
return request
|
||||
|
||||
def read_response(self, request_method='', body_size_limit=None, include_body=True):
|
||||
def read_response(
|
||||
self,
|
||||
request_method='',
|
||||
body_size_limit=None,
|
||||
include_body=True,
|
||||
):
|
||||
self.perform_connection_preface()
|
||||
|
||||
timestamp_start = time.time()
|
||||
@ -123,7 +133,6 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def assemble_request(self, request):
|
||||
assert isinstance(request, semantics.Request)
|
||||
|
||||
@ -133,13 +142,13 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
headers = request.headers.copy()
|
||||
|
||||
if not ':authority' in headers.keys():
|
||||
if ':authority' not in headers.keys():
|
||||
headers.add(':authority', bytes(authority), prepend=True)
|
||||
if not ':scheme' in headers.keys():
|
||||
if ':scheme' not in headers.keys():
|
||||
headers.add(':scheme', bytes(request.scheme), prepend=True)
|
||||
if not ':path' in headers.keys():
|
||||
if ':path' not in headers.keys():
|
||||
headers.add(':path', bytes(request.path), prepend=True)
|
||||
if not ':method' in headers.keys():
|
||||
if ':method' not in headers.keys():
|
||||
headers.add(':method', bytes(request.method), prepend=True)
|
||||
|
||||
headers = headers.items()
|
||||
@ -158,7 +167,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
headers = response.headers.copy()
|
||||
|
||||
if not ':status' in headers.keys():
|
||||
if ':status' not in headers.keys():
|
||||
headers.add(':status', bytes(str(response.status_code)), prepend=True)
|
||||
|
||||
headers = headers.items()
|
||||
|
@ -1,13 +1,9 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
import binascii
|
||||
import collections
|
||||
import string
|
||||
import sys
|
||||
import urllib
|
||||
import urlparse
|
||||
|
||||
from .. import utils, odict
|
||||
from . import cookies
|
||||
from . import cookies, exceptions
|
||||
from netlib import utils, encoding
|
||||
|
||||
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
|
||||
@ -18,11 +14,11 @@ CONTENT_MISSING = 0
|
||||
|
||||
class ProtocolMixin(object):
|
||||
|
||||
def read_request(self):
|
||||
raise NotImplemented
|
||||
def read_request(self, *args, **kwargs): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
def read_response(self):
|
||||
raise NotImplemented
|
||||
def read_response(self, *args, **kwargs): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
def assemble(self, message):
|
||||
if isinstance(message, Request):
|
||||
@ -32,14 +28,23 @@ class ProtocolMixin(object):
|
||||
else:
|
||||
raise ValueError("HTTP message not supported.")
|
||||
|
||||
def assemble_request(self, request):
|
||||
raise NotImplemented
|
||||
def assemble_request(self, *args, **kwargs): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
def assemble_response(self, response):
|
||||
raise NotImplemented
|
||||
def assemble_response(self, *args, **kwargs): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Request(object):
|
||||
# This list is adopted legacy code.
|
||||
# We probably don't need to strip off keep-alive.
|
||||
_headers_to_strip_off = [
|
||||
'Proxy-Connection',
|
||||
'Keep-Alive',
|
||||
'Connection',
|
||||
'Transfer-Encoding',
|
||||
'Upgrade',
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -71,7 +76,6 @@ class Request(object):
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
@ -114,7 +118,7 @@ class Request(object):
|
||||
self.httpversion[1],
|
||||
)
|
||||
else:
|
||||
raise http.HttpError(400, "Invalid request form")
|
||||
raise exceptions.HttpError(400, "Invalid request form")
|
||||
|
||||
def anticache(self):
|
||||
"""
|
||||
@ -143,7 +147,7 @@ class Request(object):
|
||||
if self.headers["accept-encoding"]:
|
||||
self.headers["accept-encoding"] = [
|
||||
', '.join(
|
||||
e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])]
|
||||
e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))]
|
||||
|
||||
def update_host_header(self):
|
||||
"""
|
||||
@ -317,17 +321,18 @@ class Request(object):
|
||||
self.scheme, self.host, self.port, self.path = parts
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
def content(self): # pragma: no cover
|
||||
# TODO: remove deprecated getter
|
||||
return self.body
|
||||
|
||||
@content.setter
|
||||
def content(self, content):
|
||||
def content(self, content): # pragma: no cover
|
||||
# TODO: remove deprecated setter
|
||||
self.body = content
|
||||
|
||||
|
||||
class EmptyRequest(Request):
|
||||
|
||||
def __init__(self):
|
||||
super(EmptyRequest, self).__init__(
|
||||
form_in="",
|
||||
@ -339,10 +344,15 @@ class EmptyRequest(Request):
|
||||
httpversion=(0, 0),
|
||||
headers=odict.ODictCaseless(),
|
||||
body="",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class Response(object):
|
||||
_headers_to_strip_off = [
|
||||
'Proxy-Connection',
|
||||
'Alternate-Protocol',
|
||||
'Alt-Svc',
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -368,7 +378,6 @@ class Response(object):
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
@ -388,11 +397,9 @@ class Response(object):
|
||||
status_code=self.status_code,
|
||||
msg=self.msg,
|
||||
contenttype=self.headers.get_first(
|
||||
"content-type", "unknown content type"
|
||||
),
|
||||
size=size
|
||||
)
|
||||
|
||||
"content-type",
|
||||
"unknown content type"),
|
||||
size=size)
|
||||
|
||||
def get_cookies(self):
|
||||
"""
|
||||
@ -430,21 +437,21 @@ class Response(object):
|
||||
self.headers["Set-Cookie"] = values
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
def content(self): # pragma: no cover
|
||||
# TODO: remove deprecated getter
|
||||
return self.body
|
||||
|
||||
@content.setter
|
||||
def content(self, content):
|
||||
def content(self, content): # pragma: no cover
|
||||
# TODO: remove deprecated setter
|
||||
self.body = content
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
def code(self): # pragma: no cover
|
||||
# TODO: remove deprecated getter
|
||||
return self.status_code
|
||||
|
||||
@code.setter
|
||||
def code(self, code):
|
||||
def code(self, code): # pragma: no cover
|
||||
# TODO: remove deprecated setter
|
||||
self.status_code = code
|
||||
|
@ -91,8 +91,9 @@ class ODict(object):
|
||||
self.lst = self._filter_lst(k, self.lst)
|
||||
|
||||
def __contains__(self, k):
|
||||
k = self._kconv(k)
|
||||
for i in self.lst:
|
||||
if self._kconv(i[0]) == self._kconv(k):
|
||||
if self._kconv(i[0]) == k:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -69,8 +69,6 @@ def raises(exc, obj, *args, **kwargs):
|
||||
test_data = utils.Data(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
def treq(content="content", scheme="http", host="address", port=22):
|
||||
"""
|
||||
@return: libmproxy.protocol.http.HTTPRequest
|
||||
@ -119,7 +117,7 @@ def tresp(content="message"):
|
||||
"OK",
|
||||
headers,
|
||||
content,
|
||||
time.time(),
|
||||
time.time(),
|
||||
timestamp_start=time.time(),
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
return resp
|
||||
|
@ -4,6 +4,7 @@ import cgi
|
||||
import urllib
|
||||
import urlparse
|
||||
import string
|
||||
import re
|
||||
|
||||
|
||||
def isascii(s):
|
||||
@ -118,6 +119,7 @@ def pretty_size(size):
|
||||
|
||||
|
||||
class Data(object):
|
||||
|
||||
def __init__(self, name):
|
||||
m = __import__(name)
|
||||
dirname, _ = os.path.split(m.__file__)
|
||||
@ -136,8 +138,6 @@ class Data(object):
|
||||
return fullpath
|
||||
|
||||
|
||||
|
||||
|
||||
def is_valid_port(port):
|
||||
if not 0 <= port <= 65535:
|
||||
return False
|
||||
@ -220,6 +220,7 @@ def hostport(scheme, host, port):
|
||||
else:
|
||||
return "%s:%s" % (host, port)
|
||||
|
||||
|
||||
def unparse_url(scheme, host, port, path=""):
|
||||
"""
|
||||
Returns a URL string, constructed from the specified compnents.
|
||||
@ -234,8 +235,64 @@ def urlencode(s):
|
||||
s = [tuple(i) for i in s]
|
||||
return urllib.urlencode(s, False)
|
||||
|
||||
|
||||
def urldecode(s):
|
||||
"""
|
||||
Takes a urlencoded string and returns a list of (key, value) tuples.
|
||||
"""
|
||||
return cgi.parse_qsl(s, keep_blank_values=True)
|
||||
|
||||
|
||||
def parse_content_type(c):
|
||||
"""
|
||||
A simple parser for content-type values. Returns a (type, subtype,
|
||||
parameters) tuple, where type and subtype are strings, and parameters
|
||||
is a dict. If the string could not be parsed, return None.
|
||||
|
||||
E.g. the following string:
|
||||
|
||||
text/html; charset=UTF-8
|
||||
|
||||
Returns:
|
||||
|
||||
("text", "html", {"charset": "UTF-8"})
|
||||
"""
|
||||
parts = c.split(";", 1)
|
||||
ts = parts[0].split("/", 1)
|
||||
if len(ts) != 2:
|
||||
return None
|
||||
d = {}
|
||||
if len(parts) == 2:
|
||||
for i in parts[1].split(";"):
|
||||
clause = i.split("=", 1)
|
||||
if len(clause) == 2:
|
||||
d[clause[0].strip()] = clause[1].strip()
|
||||
return ts[0].lower(), ts[1].lower(), d
|
||||
|
||||
|
||||
def multipartdecode(hdrs, content):
|
||||
"""
|
||||
Takes a multipart boundary encoded string and returns list of (key, value) tuples.
|
||||
"""
|
||||
v = hdrs.get_first("content-type")
|
||||
if v:
|
||||
v = parse_content_type(v)
|
||||
if not v:
|
||||
return []
|
||||
boundary = v[2].get("boundary")
|
||||
if not boundary:
|
||||
return []
|
||||
|
||||
rx = re.compile(r'\bname="([^"]+)"')
|
||||
r = []
|
||||
|
||||
for i in content.split("--" + boundary):
|
||||
parts = i.splitlines()
|
||||
if len(parts) > 1 and parts[0][0:2] != "--":
|
||||
match = rx.search(parts[1])
|
||||
if match:
|
||||
key = match.group(1)
|
||||
value = "".join(parts[3 + parts[2:].index(""):])
|
||||
r.append((key, value))
|
||||
return r
|
||||
return []
|
||||
|
@ -1,12 +1,11 @@
|
||||
from __future__ import absolute_import
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import struct
|
||||
import io
|
||||
|
||||
from .protocol import Masker
|
||||
from netlib import utils, odict, tcp
|
||||
from netlib import tcp
|
||||
from netlib import utils
|
||||
|
||||
DEFAULT = object()
|
||||
|
||||
@ -22,6 +21,7 @@ OPCODE = utils.BiDi(
|
||||
PONG=0x0a
|
||||
)
|
||||
|
||||
|
||||
class FrameHeader(object):
|
||||
|
||||
def __init__(
|
||||
|
@ -2,10 +2,9 @@ from __future__ import absolute_import
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import struct
|
||||
import io
|
||||
|
||||
from netlib import utils, odict, tcp
|
||||
from netlib import odict
|
||||
from netlib import utils
|
||||
|
||||
# Colleciton of utility functions that implement small portions of the RFC6455
|
||||
# WebSockets Protocol Useful for building WebSocket clients and servers.
|
||||
@ -26,6 +25,7 @@ HEADER_WEBSOCKET_KEY = 'sec-websocket-key'
|
||||
HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept'
|
||||
HEADER_WEBSOCKET_VERSION = 'sec-websocket-version'
|
||||
|
||||
|
||||
class Masker(object):
|
||||
|
||||
"""
|
||||
@ -53,6 +53,7 @@ class Masker(object):
|
||||
self.offset += len(ret)
|
||||
return ret
|
||||
|
||||
|
||||
class WebsocketsProtocol(object):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -1,18 +1,41 @@
|
||||
import cStringIO
|
||||
import textwrap
|
||||
import binascii
|
||||
|
||||
from netlib import http, odict, tcp, tutils
|
||||
from netlib.http import semantics
|
||||
from netlib.http.http1 import HTTP1Protocol
|
||||
from ... import tservers
|
||||
|
||||
|
||||
def mock_protocol(data='', chunked=False):
|
||||
class NoContentLengthHTTPHandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n")
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
def mock_protocol(data=''):
|
||||
rfile = cStringIO.StringIO(data)
|
||||
wfile = cStringIO.StringIO()
|
||||
return HTTP1Protocol(rfile=rfile, wfile=wfile)
|
||||
|
||||
|
||||
def match_http_string(data):
|
||||
return textwrap.dedent(data).strip().replace('\n', '\r\n')
|
||||
|
||||
|
||||
def test_stripped_chunked_encoding_no_content():
|
||||
"""
|
||||
https://github.com/mitmproxy/mitmproxy/issues/186
|
||||
"""
|
||||
|
||||
r = tutils.treq(content="")
|
||||
r.headers["Transfer-Encoding"] = ["chunked"]
|
||||
assert "Content-Length" in mock_protocol()._assemble_request_headers(r)
|
||||
|
||||
r = tutils.tresp(content="")
|
||||
r.headers["Transfer-Encoding"] = ["chunked"]
|
||||
assert "Content-Length" in mock_protocol()._assemble_response_headers(r)
|
||||
|
||||
|
||||
def test_has_chunked_encoding():
|
||||
h = odict.ODictCaseless()
|
||||
@ -75,7 +98,6 @@ def test_connection_close():
|
||||
assert HTTP1Protocol.connection_close((1, 1), h)
|
||||
|
||||
|
||||
|
||||
def test_read_http_body_request():
|
||||
h = odict.ODictCaseless()
|
||||
data = "testing"
|
||||
@ -85,7 +107,7 @@ def test_read_http_body_request():
|
||||
def test_read_http_body_response():
|
||||
h = odict.ODictCaseless()
|
||||
data = "testing"
|
||||
assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing"
|
||||
assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing"
|
||||
|
||||
|
||||
def test_read_http_body():
|
||||
@ -129,13 +151,13 @@ def test_read_http_body():
|
||||
# test no content length: limit > actual content
|
||||
h = odict.ODictCaseless()
|
||||
data = "testing"
|
||||
assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7
|
||||
assert len(mock_protocol(data).read_http_body(h, 100, "GET", 200, False)) == 7
|
||||
|
||||
# test no content length: limit < actual content
|
||||
data = "testing"
|
||||
tutils.raises(
|
||||
http.HttpError,
|
||||
mock_protocol(data, chunked=True).read_http_body,
|
||||
mock_protocol(data).read_http_body,
|
||||
h, 4, "GET", 200, False
|
||||
)
|
||||
|
||||
@ -143,7 +165,7 @@ def test_read_http_body():
|
||||
h = odict.ODictCaseless()
|
||||
h["transfer-encoding"] = ["chunked"]
|
||||
data = "5\r\naaaaa\r\n0\r\n\r\n"
|
||||
assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa"
|
||||
assert mock_protocol(data).read_http_body(h, 100, "GET", 200, False) == "aaaaa"
|
||||
|
||||
|
||||
def test_expected_http_body_size():
|
||||
@ -167,6 +189,13 @@ def test_expected_http_body_size():
|
||||
assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0
|
||||
|
||||
|
||||
def test_get_request_line():
|
||||
data = "\nfoo"
|
||||
p = mock_protocol(data)
|
||||
assert p._get_request_line() == "foo"
|
||||
assert not p._get_request_line()
|
||||
|
||||
|
||||
def test_parse_http_protocol():
|
||||
assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1)
|
||||
assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0)
|
||||
@ -269,96 +298,7 @@ class TestReadHeaders:
|
||||
assert self._read(data) is None
|
||||
|
||||
|
||||
class NoContentLengthHTTPHandler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n")
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class TestReadResponseNoContentLength(tservers.ServerTestBase):
|
||||
handler = NoContentLengthHTTPHandler
|
||||
|
||||
def test_no_content_length(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
resp = HTTP1Protocol(c).read_response("GET", None)
|
||||
assert resp.body == "bar\r\n\r\n"
|
||||
|
||||
|
||||
def test_read_response():
|
||||
def tst(data, method, body_size_limit, include_body=True):
|
||||
data = textwrap.dedent(data)
|
||||
return mock_protocol(data).read_response(
|
||||
method, body_size_limit, include_body=include_body
|
||||
)
|
||||
|
||||
tutils.raises("server disconnect", tst, "", "GET", None)
|
||||
tutils.raises("invalid server response", tst, "foo", "GET", None)
|
||||
data = """
|
||||
HTTP/1.1 200 OK
|
||||
"""
|
||||
assert tst(data, "GET", None) == http.Response(
|
||||
(1, 1), 200, 'OK', odict.ODictCaseless(), ''
|
||||
)
|
||||
data = """
|
||||
HTTP/1.1 200
|
||||
"""
|
||||
assert tst(data, "GET", None) == http.Response(
|
||||
(1, 1), 200, '', odict.ODictCaseless(), ''
|
||||
)
|
||||
data = """
|
||||
HTTP/x 200 OK
|
||||
"""
|
||||
tutils.raises("invalid http version", tst, data, "GET", None)
|
||||
data = """
|
||||
HTTP/1.1 xx OK
|
||||
"""
|
||||
tutils.raises("invalid server response", tst, data, "GET", None)
|
||||
|
||||
data = """
|
||||
HTTP/1.1 100 CONTINUE
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
"""
|
||||
assert tst(data, "GET", None) == http.Response(
|
||||
(1, 1), 100, 'CONTINUE', odict.ODictCaseless(), ''
|
||||
)
|
||||
|
||||
data = """
|
||||
HTTP/1.1 200 OK
|
||||
Content-Length: 3
|
||||
|
||||
foo
|
||||
"""
|
||||
assert tst(data, "GET", None).body == 'foo'
|
||||
assert tst(data, "HEAD", None).body == ''
|
||||
|
||||
data = """
|
||||
HTTP/1.1 200 OK
|
||||
\tContent-Length: 3
|
||||
|
||||
foo
|
||||
"""
|
||||
tutils.raises("invalid headers", tst, data, "GET", None)
|
||||
|
||||
data = """
|
||||
HTTP/1.1 200 OK
|
||||
Content-Length: 3
|
||||
|
||||
foo
|
||||
"""
|
||||
assert tst(data, "GET", None, include_body=False).body is None
|
||||
|
||||
|
||||
def test_get_request_line():
|
||||
data = "\nfoo"
|
||||
p = mock_protocol(data)
|
||||
assert p._get_request_line() == "foo"
|
||||
assert not p._get_request_line()
|
||||
|
||||
|
||||
class TestReadRequest():
|
||||
class TestReadRequest(object):
|
||||
|
||||
def tst(self, data, **kwargs):
|
||||
return mock_protocol(data).read_request(**kwargs)
|
||||
@ -385,6 +325,10 @@ class TestReadRequest():
|
||||
"\r\n"
|
||||
)
|
||||
|
||||
def test_empty(self):
|
||||
v = self.tst("", allow_empty=True)
|
||||
assert isinstance(v, semantics.EmptyRequest)
|
||||
|
||||
def test_asterisk_form_in(self):
|
||||
v = self.tst("OPTIONS * HTTP/1.1")
|
||||
assert v.form_in == "relative"
|
||||
@ -427,3 +371,131 @@ class TestReadRequest():
|
||||
assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n"
|
||||
assert v.body == "foo"
|
||||
assert p.tcp_handler.rfile.read(3) == "bar"
|
||||
|
||||
|
||||
class TestReadResponse(object):
|
||||
def tst(self, data, method, body_size_limit, include_body=True):
|
||||
data = textwrap.dedent(data)
|
||||
return mock_protocol(data).read_response(
|
||||
method, body_size_limit, include_body=include_body
|
||||
)
|
||||
|
||||
def test_errors(self):
|
||||
tutils.raises("server disconnect", self.tst, "", "GET", None)
|
||||
tutils.raises("invalid server response", self.tst, "foo", "GET", None)
|
||||
|
||||
def test_simple(self):
|
||||
data = """
|
||||
HTTP/1.1 200
|
||||
"""
|
||||
assert self.tst(data, "GET", None) == http.Response(
|
||||
(1, 1), 200, '', odict.ODictCaseless(), ''
|
||||
)
|
||||
|
||||
def test_simple_message(self):
|
||||
data = """
|
||||
HTTP/1.1 200 OK
|
||||
"""
|
||||
assert self.tst(data, "GET", None) == http.Response(
|
||||
(1, 1), 200, 'OK', odict.ODictCaseless(), ''
|
||||
)
|
||||
|
||||
def test_invalid_http_version(self):
|
||||
data = """
|
||||
HTTP/x 200 OK
|
||||
"""
|
||||
tutils.raises("invalid http version", self.tst, data, "GET", None)
|
||||
|
||||
def test_invalid_status_code(self):
|
||||
data = """
|
||||
HTTP/1.1 xx OK
|
||||
"""
|
||||
tutils.raises("invalid server response", self.tst, data, "GET", None)
|
||||
|
||||
def test_valid_with_continue(self):
|
||||
data = """
|
||||
HTTP/1.1 100 CONTINUE
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
"""
|
||||
assert self.tst(data, "GET", None) == http.Response(
|
||||
(1, 1), 100, 'CONTINUE', odict.ODictCaseless(), ''
|
||||
)
|
||||
|
||||
def test_simple_body(self):
|
||||
data = """
|
||||
HTTP/1.1 200 OK
|
||||
Content-Length: 3
|
||||
|
||||
foo
|
||||
"""
|
||||
assert self.tst(data, "GET", None).body == 'foo'
|
||||
assert self.tst(data, "HEAD", None).body == ''
|
||||
|
||||
def test_invalid_headers(self):
|
||||
data = """
|
||||
HTTP/1.1 200 OK
|
||||
\tContent-Length: 3
|
||||
|
||||
foo
|
||||
"""
|
||||
tutils.raises("invalid headers", self.tst, data, "GET", None)
|
||||
|
||||
def test_without_body(self):
|
||||
data = """
|
||||
HTTP/1.1 200 OK
|
||||
Content-Length: 3
|
||||
|
||||
foo
|
||||
"""
|
||||
assert self.tst(data, "GET", None, include_body=False).body is None
|
||||
|
||||
|
||||
class TestReadResponseNoContentLength(tservers.ServerTestBase):
|
||||
handler = NoContentLengthHTTPHandler
|
||||
|
||||
def test_no_content_length(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
resp = HTTP1Protocol(c).read_response("GET", None)
|
||||
assert resp.body == "bar\r\n\r\n"
|
||||
|
||||
|
||||
class TestAssembleRequest(object):
|
||||
def test_simple(self):
|
||||
req = tutils.treq()
|
||||
b = HTTP1Protocol().assemble_request(req)
|
||||
assert b == match_http_string("""
|
||||
GET /path HTTP/1.1
|
||||
header: qvalue
|
||||
Host: address:22
|
||||
Content-Length: 7
|
||||
|
||||
content""")
|
||||
|
||||
def test_body_missing(self):
|
||||
req = tutils.treq(content=semantics.CONTENT_MISSING)
|
||||
tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req)
|
||||
|
||||
def test_not_a_request(self):
|
||||
tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo')
|
||||
|
||||
|
||||
class TestAssembleResponse(object):
|
||||
def test_simple(self):
|
||||
resp = tutils.tresp()
|
||||
b = HTTP1Protocol().assemble_response(resp)
|
||||
print(b)
|
||||
assert b == match_http_string("""
|
||||
HTTP/1.1 200 OK
|
||||
header_response: svalue
|
||||
Content-Length: 7
|
||||
|
||||
message""")
|
||||
|
||||
def test_body_missing(self):
|
||||
resp = tutils.tresp(content=semantics.CONTENT_MISSING)
|
||||
tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp)
|
||||
|
||||
def test_not_a_request(self):
|
||||
tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo')
|
||||
|
@ -1,10 +1,25 @@
|
||||
import OpenSSL
|
||||
import mock
|
||||
|
||||
from netlib import tcp, odict, http, tutils
|
||||
from netlib.http import http2
|
||||
from netlib.http.http2 import HTTP2Protocol
|
||||
from netlib.http.http2.frame import *
|
||||
from ... import tservers
|
||||
|
||||
class TestTCPHandlerWrapper:
|
||||
def test_wrapped(self):
|
||||
h = http2.TCPHandler(rfile='foo', wfile='bar')
|
||||
p = HTTP2Protocol(h)
|
||||
assert p.tcp_handler.rfile == 'foo'
|
||||
assert p.tcp_handler.wfile == 'bar'
|
||||
|
||||
def test_direct(self):
|
||||
p = HTTP2Protocol(rfile='foo', wfile='bar')
|
||||
assert isinstance(p.tcp_handler, http2.TCPHandler)
|
||||
assert p.tcp_handler.rfile == 'foo'
|
||||
assert p.tcp_handler.wfile == 'bar'
|
||||
|
||||
|
||||
class EchoHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
@ -16,10 +31,40 @@ class EchoHandler(tcp.BaseHandler):
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class TestProtocol:
|
||||
@mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
|
||||
@mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
|
||||
def test_perform_connection_preface(self, mock_client_method, mock_server_method):
|
||||
protocol = HTTP2Protocol(is_server=False)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
protocol.perform_connection_preface()
|
||||
assert not mock_client_method.called
|
||||
assert not mock_server_method.called
|
||||
|
||||
protocol.perform_connection_preface(force=True)
|
||||
assert mock_client_method.called
|
||||
assert not mock_server_method.called
|
||||
|
||||
@mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
|
||||
@mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
|
||||
def test_perform_connection_preface_server(self, mock_client_method, mock_server_method):
|
||||
protocol = HTTP2Protocol(is_server=True)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
protocol.perform_connection_preface()
|
||||
assert not mock_client_method.called
|
||||
assert not mock_server_method.called
|
||||
|
||||
protocol.perform_connection_preface(force=True)
|
||||
assert not mock_client_method.called
|
||||
assert mock_server_method.called
|
||||
|
||||
|
||||
class TestCheckALPNMatch(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
ssl = dict(
|
||||
alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2,
|
||||
alpn_select=HTTP2Protocol.ALPN_PROTO_H2,
|
||||
)
|
||||
|
||||
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
|
||||
@ -27,8 +72,8 @@ class TestCheckALPNMatch(tservers.ServerTestBase):
|
||||
def test_check_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2])
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2])
|
||||
protocol = HTTP2Protocol(c)
|
||||
assert protocol.check_alpn()
|
||||
|
||||
|
||||
@ -43,8 +88,8 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):
|
||||
def test_check_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2])
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2])
|
||||
protocol = HTTP2Protocol(c)
|
||||
tutils.raises(NotImplementedError, protocol.check_alpn)
|
||||
|
||||
|
||||
@ -76,8 +121,13 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
|
||||
def test_perform_server_connection_preface(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol = HTTP2Protocol(c)
|
||||
|
||||
assert not protocol.connection_preface_performed
|
||||
protocol.perform_server_connection_preface()
|
||||
assert protocol.connection_preface_performed
|
||||
|
||||
tutils.raises(tcp.NetLibIncomplete, protocol.perform_server_connection_preface, force=True)
|
||||
|
||||
|
||||
class TestPerformClientConnectionPreface(tservers.ServerTestBase):
|
||||
@ -107,13 +157,16 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
|
||||
def test_perform_client_connection_preface(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol = HTTP2Protocol(c)
|
||||
|
||||
assert not protocol.connection_preface_performed
|
||||
protocol.perform_client_connection_preface()
|
||||
assert protocol.connection_preface_performed
|
||||
|
||||
|
||||
class TestClientStreamIds():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol = HTTP2Protocol(c)
|
||||
|
||||
def test_client_stream_ids(self):
|
||||
assert self.protocol.current_stream_id is None
|
||||
@ -127,7 +180,7 @@ class TestClientStreamIds():
|
||||
|
||||
class TestServerStreamIds():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
protocol = http2.HTTP2Protocol(c, is_server=True)
|
||||
protocol = HTTP2Protocol(c, is_server=True)
|
||||
|
||||
def test_server_stream_ids(self):
|
||||
assert self.protocol.current_stream_id is None
|
||||
@ -154,7 +207,7 @@ class TestApplySettings(tservers.ServerTestBase):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol = HTTP2Protocol(c)
|
||||
|
||||
protocol._apply_settings({
|
||||
SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo',
|
||||
@ -182,13 +235,13 @@ class TestCreateHeaders():
|
||||
(b':scheme', b'https'),
|
||||
(b'foo', b'bar')]
|
||||
|
||||
bytes = http2.HTTP2Protocol(self.c)._create_headers(
|
||||
bytes = HTTP2Protocol(self.c)._create_headers(
|
||||
headers, 1, end_stream=True)
|
||||
assert b''.join(bytes) ==\
|
||||
'000014010500000001824488355217caf3a69a3f87408294e7838c767f'\
|
||||
.decode('hex')
|
||||
|
||||
bytes = http2.HTTP2Protocol(self.c)._create_headers(
|
||||
bytes = HTTP2Protocol(self.c)._create_headers(
|
||||
headers, 1, end_stream=False)
|
||||
assert b''.join(bytes) ==\
|
||||
'000014010400000001824488355217caf3a69a3f87408294e7838c767f'\
|
||||
@ -199,7 +252,7 @@ class TestCreateHeaders():
|
||||
|
||||
class TestCreateBody():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol = HTTP2Protocol(c)
|
||||
|
||||
def test_create_body_empty(self):
|
||||
bytes = self.protocol._create_body(b'', 1)
|
||||
@ -215,98 +268,6 @@ class TestCreateBody():
|
||||
# TODO: add test for too large frames
|
||||
|
||||
|
||||
class TestAssembleRequest():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
|
||||
def test_assemble_request_simple(self):
|
||||
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
|
||||
'',
|
||||
'GET',
|
||||
'https',
|
||||
'',
|
||||
'',
|
||||
'/',
|
||||
(2, 0),
|
||||
None,
|
||||
None,
|
||||
))
|
||||
assert len(bytes) == 1
|
||||
assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex')
|
||||
|
||||
def test_assemble_request_with_body(self):
|
||||
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
|
||||
'',
|
||||
'GET',
|
||||
'https',
|
||||
'',
|
||||
'',
|
||||
'/',
|
||||
(2, 0),
|
||||
odict.ODictCaseless([('foo', 'bar')]),
|
||||
'foobar',
|
||||
))
|
||||
assert len(bytes) == 2
|
||||
assert bytes[0] ==\
|
||||
'0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex')
|
||||
assert bytes[1] ==\
|
||||
'000006000100000001666f6f626172'.decode('hex')
|
||||
|
||||
|
||||
class TestReadResponse(tservers.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(
|
||||
b'00000801040000000188628594e78c767f'.decode('hex'))
|
||||
self.wfile.write(
|
||||
b'000006000100000001666f6f626172'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_read_response(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
assert resp.body == b'foobar'
|
||||
|
||||
|
||||
class TestReadEmptyResponse(tservers.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(
|
||||
b'00000801050000000188628594e78c767f'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_read_empty_response(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.stream_id
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
assert resp.body == b''
|
||||
|
||||
|
||||
class TestReadRequest(tservers.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
@ -323,7 +284,7 @@ class TestReadRequest(tservers.ServerTestBase):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c, is_server=True)
|
||||
protocol = HTTP2Protocol(c, is_server=True)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_request()
|
||||
@ -333,11 +294,138 @@ class TestReadRequest(tservers.ServerTestBase):
|
||||
assert resp.body == b'foobar'
|
||||
|
||||
|
||||
class TestCreateResponse():
|
||||
class TestReadResponse(tservers.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(
|
||||
b'00000801040000000188628594e78c767f'.decode('hex'))
|
||||
self.wfile.write(
|
||||
b'000006000100000001666f6f626172'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_read_response(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = HTTP2Protocol(c)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
assert resp.body == b'foobar'
|
||||
assert resp.timestamp_end
|
||||
|
||||
def test_read_response_no_body(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = HTTP2Protocol(c)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_response(include_body=False)
|
||||
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
assert resp.body == b'foobar' # TODO: this should be true: assert resp.body == http.CONTENT_MISSING
|
||||
assert not resp.timestamp_end
|
||||
|
||||
|
||||
class TestReadEmptyResponse(tservers.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(
|
||||
b'00000801050000000188628594e78c767f'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_read_empty_response(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = HTTP2Protocol(c)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.stream_id
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
assert resp.body == b''
|
||||
|
||||
|
||||
class TestAssembleRequest(object):
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
|
||||
def test_create_response_simple(self):
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
|
||||
def test_request_simple(self):
|
||||
bytes = HTTP2Protocol(self.c).assemble_request(http.Request(
|
||||
'',
|
||||
'GET',
|
||||
'https',
|
||||
'',
|
||||
'',
|
||||
'/',
|
||||
(2, 0),
|
||||
None,
|
||||
None,
|
||||
))
|
||||
assert len(bytes) == 1
|
||||
assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex')
|
||||
|
||||
def test_request_with_stream_id(self):
|
||||
req = http.Request(
|
||||
'',
|
||||
'GET',
|
||||
'https',
|
||||
'',
|
||||
'',
|
||||
'/',
|
||||
(2, 0),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
req.stream_id = 0x42
|
||||
bytes = HTTP2Protocol(self.c).assemble_request(req)
|
||||
assert len(bytes) == 1
|
||||
print(bytes[0].encode('hex'))
|
||||
assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex')
|
||||
|
||||
def test_request_with_body(self):
|
||||
bytes = HTTP2Protocol(self.c).assemble_request(http.Request(
|
||||
'',
|
||||
'GET',
|
||||
'https',
|
||||
'',
|
||||
'',
|
||||
'/',
|
||||
(2, 0),
|
||||
odict.ODictCaseless([('foo', 'bar')]),
|
||||
'foobar',
|
||||
))
|
||||
assert len(bytes) == 2
|
||||
assert bytes[0] ==\
|
||||
'0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex')
|
||||
assert bytes[1] ==\
|
||||
'000006000100000001666f6f626172'.decode('hex')
|
||||
|
||||
|
||||
class TestAssembleResponse(object):
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
|
||||
def test_simple(self):
|
||||
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
|
||||
(2, 0),
|
||||
200,
|
||||
))
|
||||
@ -345,8 +433,19 @@ class TestCreateResponse():
|
||||
assert bytes[0] ==\
|
||||
'00000101050000000288'.decode('hex')
|
||||
|
||||
def test_create_response_with_body(self):
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
|
||||
def test_with_stream_id(self):
|
||||
resp = http.Response(
|
||||
(2, 0),
|
||||
200,
|
||||
)
|
||||
resp.stream_id = 0x42
|
||||
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp)
|
||||
assert len(bytes) == 1
|
||||
assert bytes[0] ==\
|
||||
'00000101050000004288'.decode('hex')
|
||||
|
||||
def test_with_body(self):
|
||||
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
|
||||
(2, 0),
|
||||
200,
|
||||
'',
|
||||
|
@ -1,6 +1,27 @@
|
||||
from netlib.http.exceptions import *
|
||||
from netlib import odict
|
||||
|
||||
def test_HttpAuthenticationError():
|
||||
x = HttpAuthenticationError({"foo": "bar"})
|
||||
assert str(x)
|
||||
assert "foo" in x.headers
|
||||
class TestHttpError:
|
||||
def test_simple(self):
|
||||
e = HttpError(404, "Not found")
|
||||
assert str(e)
|
||||
|
||||
class TestHttpAuthenticationError:
|
||||
def test_init(self):
|
||||
headers = odict.ODictCaseless([("foo", "bar")])
|
||||
x = HttpAuthenticationError(headers)
|
||||
assert str(x)
|
||||
assert isinstance(x.headers, odict.ODictCaseless)
|
||||
assert x.code == 407
|
||||
assert x.headers == headers
|
||||
print(x.headers.keys())
|
||||
assert "foo" in x.headers.keys()
|
||||
|
||||
def test_header_conversion(self):
|
||||
headers = {"foo": "bar"}
|
||||
x = HttpAuthenticationError(headers)
|
||||
assert isinstance(x.headers, odict.ODictCaseless)
|
||||
assert x.headers.lst == headers.items()
|
||||
|
||||
def test_repr(self):
|
||||
assert repr(HttpAuthenticationError()) == "Proxy Authentication Required"
|
||||
|
@ -1,18 +1,275 @@
|
||||
import cStringIO
|
||||
import textwrap
|
||||
import binascii
|
||||
from mock import MagicMock
|
||||
import mock
|
||||
|
||||
from netlib import http, odict, tcp, tutils
|
||||
from netlib.http import http1
|
||||
from netlib import http
|
||||
from netlib import odict
|
||||
from netlib import tutils
|
||||
from netlib import utils
|
||||
from netlib.http import semantics
|
||||
from netlib.http.semantics import CONTENT_MISSING
|
||||
from .. import tservers
|
||||
|
||||
def test_httperror():
|
||||
e = http.exceptions.HttpError(404, "Not found")
|
||||
assert str(e)
|
||||
class TestProtocolMixin(object):
|
||||
@mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
|
||||
@mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
|
||||
def test_assemble_request(self, mock_request_method, mock_response_method):
|
||||
p = semantics.ProtocolMixin()
|
||||
p.assemble(tutils.treq())
|
||||
assert mock_request_method.called
|
||||
assert not mock_response_method.called
|
||||
|
||||
@mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
|
||||
@mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
|
||||
def test_assemble_response(self, mock_request_method, mock_response_method):
|
||||
p = semantics.ProtocolMixin()
|
||||
p.assemble(tutils.tresp())
|
||||
assert not mock_request_method.called
|
||||
assert mock_response_method.called
|
||||
|
||||
def test_assemble_foo(self):
|
||||
p = semantics.ProtocolMixin()
|
||||
tutils.raises(ValueError, p.assemble, 'foo')
|
||||
|
||||
class TestRequest(object):
|
||||
def test_repr(self):
|
||||
r = tutils.treq()
|
||||
assert repr(r)
|
||||
|
||||
def test_headers_odict(self):
|
||||
tutils.raises(AssertionError, semantics.Request,
|
||||
'form_in',
|
||||
'method',
|
||||
'scheme',
|
||||
'host',
|
||||
'port',
|
||||
'path',
|
||||
(1, 1),
|
||||
'foobar',
|
||||
)
|
||||
|
||||
req = semantics.Request(
|
||||
'form_in',
|
||||
'method',
|
||||
'scheme',
|
||||
'host',
|
||||
'port',
|
||||
'path',
|
||||
(1, 1),
|
||||
)
|
||||
assert isinstance(req.headers, odict.ODictCaseless)
|
||||
|
||||
def test_equal(self):
|
||||
a = tutils.treq()
|
||||
b = tutils.treq()
|
||||
assert a == b
|
||||
|
||||
assert not a == 'foo'
|
||||
assert not b == 'foo'
|
||||
assert not 'foo' == a
|
||||
assert not 'foo' == b
|
||||
|
||||
def test_legacy_first_line(self):
|
||||
req = tutils.treq()
|
||||
|
||||
req.form_in = 'relative'
|
||||
assert req.legacy_first_line() == "GET /path HTTP/1.1"
|
||||
|
||||
req.form_in = 'authority'
|
||||
assert req.legacy_first_line() == "GET address:22 HTTP/1.1"
|
||||
|
||||
req.form_in = 'absolute'
|
||||
assert req.legacy_first_line() == "GET http://address:22/path HTTP/1.1"
|
||||
|
||||
req.form_in = 'foobar'
|
||||
tutils.raises(http.HttpError, req.legacy_first_line)
|
||||
|
||||
def test_anticache(self):
|
||||
req = tutils.treq()
|
||||
req.headers.add("If-Modified-Since", "foo")
|
||||
req.headers.add("If-None-Match", "bar")
|
||||
req.anticache()
|
||||
assert "If-Modified-Since" not in req.headers
|
||||
assert "If-None-Match" not in req.headers
|
||||
|
||||
def test_anticomp(self):
|
||||
req = tutils.treq()
|
||||
req.headers.add("Accept-Encoding", "foobar")
|
||||
req.anticomp()
|
||||
assert req.headers["Accept-Encoding"] == ["identity"]
|
||||
|
||||
def test_constrain_encoding(self):
|
||||
req = tutils.treq()
|
||||
req.headers.add("Accept-Encoding", "identity, gzip, foo")
|
||||
req.constrain_encoding()
|
||||
assert "foo" not in req.headers.get_first("Accept-Encoding")
|
||||
|
||||
def test_update_host(self):
|
||||
req = tutils.treq()
|
||||
req.headers.add("Host", "")
|
||||
req.host = "foobar"
|
||||
req.update_host_header()
|
||||
assert req.headers.get_first("Host") == "foobar"
|
||||
|
||||
def test_get_form(self):
|
||||
req = tutils.treq()
|
||||
assert req.get_form() == odict.ODict()
|
||||
|
||||
@mock.patch("netlib.http.semantics.Request.get_form_multipart")
|
||||
@mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
|
||||
def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart):
|
||||
req = tutils.treq()
|
||||
assert req.get_form() == odict.ODict()
|
||||
|
||||
req = tutils.treq()
|
||||
req.body = "foobar"
|
||||
req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED]
|
||||
req.get_form()
|
||||
assert req.get_form_urlencoded.called
|
||||
assert not req.get_form_multipart.called
|
||||
|
||||
@mock.patch("netlib.http.semantics.Request.get_form_multipart")
|
||||
@mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
|
||||
def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart):
|
||||
req = tutils.treq()
|
||||
req.body = "foobar"
|
||||
req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART]
|
||||
req.get_form()
|
||||
assert not req.get_form_urlencoded.called
|
||||
assert req.get_form_multipart.called
|
||||
|
||||
def test_get_form_urlencoded(self):
|
||||
req = tutils.treq("foobar")
|
||||
assert req.get_form_urlencoded() == odict.ODict()
|
||||
|
||||
req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED]
|
||||
assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body))
|
||||
|
||||
def test_get_form_multipart(self):
|
||||
req = tutils.treq("foobar")
|
||||
assert req.get_form_multipart() == odict.ODict()
|
||||
|
||||
req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART]
|
||||
assert req.get_form_multipart() == odict.ODict(
|
||||
utils.multipartdecode(
|
||||
req.headers,
|
||||
req.body))
|
||||
|
||||
def test_set_form_urlencoded(self):
|
||||
req = tutils.treq()
|
||||
req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')]))
|
||||
assert req.headers.get_first("Content-Type") == semantics.HDR_FORM_URLENCODED
|
||||
assert req.body
|
||||
|
||||
def test_get_path_components(self):
|
||||
req = tutils.treq()
|
||||
assert req.get_path_components()
|
||||
# TODO: add meaningful assertions
|
||||
|
||||
def test_set_path_components(self):
|
||||
req = tutils.treq()
|
||||
req.set_path_components(["foo", "bar"])
|
||||
# TODO: add meaningful assertions
|
||||
|
||||
def test_get_query(self):
|
||||
req = tutils.treq()
|
||||
assert req.get_query().lst == []
|
||||
|
||||
req.url = "http://localhost:80/foo?bar=42"
|
||||
assert req.get_query().lst == [("bar", "42")]
|
||||
|
||||
def test_set_query(self):
|
||||
req = tutils.treq()
|
||||
req.set_query(odict.ODict([]))
|
||||
|
||||
def test_pretty_host(self):
|
||||
r = tutils.treq()
|
||||
assert r.pretty_host(True) == "address"
|
||||
assert r.pretty_host(False) == "address"
|
||||
r.headers["host"] = ["other"]
|
||||
assert r.pretty_host(True) == "other"
|
||||
assert r.pretty_host(False) == "address"
|
||||
r.host = None
|
||||
assert r.pretty_host(True) == "other"
|
||||
assert r.pretty_host(False) is None
|
||||
del r.headers["host"]
|
||||
assert r.pretty_host(True) is None
|
||||
assert r.pretty_host(False) is None
|
||||
|
||||
# Invalid IDNA
|
||||
r.headers["host"] = [".disqus.com"]
|
||||
assert r.pretty_host(True) == ".disqus.com"
|
||||
|
||||
def test_pretty_url(self):
|
||||
req = tutils.treq()
|
||||
req.form_out = "authority"
|
||||
assert req.pretty_url(True) == "address:22"
|
||||
assert req.pretty_url(False) == "address:22"
|
||||
|
||||
req.form_out = "relative"
|
||||
assert req.pretty_url(True) == "http://address:22/path"
|
||||
assert req.pretty_url(False) == "http://address:22/path"
|
||||
|
||||
def test_get_cookies_none(self):
|
||||
h = odict.ODictCaseless()
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
assert len(r.get_cookies()) == 0
|
||||
|
||||
def test_get_cookies_single(self):
|
||||
h = odict.ODictCaseless()
|
||||
h["Cookie"] = ["cookiename=cookievalue"]
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
result = r.get_cookies()
|
||||
assert len(result) == 1
|
||||
assert result['cookiename'] == ['cookievalue']
|
||||
|
||||
def test_get_cookies_double(self):
|
||||
h = odict.ODictCaseless()
|
||||
h["Cookie"] = [
|
||||
"cookiename=cookievalue;othercookiename=othercookievalue"
|
||||
]
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
result = r.get_cookies()
|
||||
assert len(result) == 2
|
||||
assert result['cookiename'] == ['cookievalue']
|
||||
assert result['othercookiename'] == ['othercookievalue']
|
||||
|
||||
def test_get_cookies_withequalsign(self):
|
||||
h = odict.ODictCaseless()
|
||||
h["Cookie"] = [
|
||||
"cookiename=coo=kievalue;othercookiename=othercookievalue"
|
||||
]
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
result = r.get_cookies()
|
||||
assert len(result) == 2
|
||||
assert result['cookiename'] == ['coo=kievalue']
|
||||
assert result['othercookiename'] == ['othercookievalue']
|
||||
|
||||
def test_set_cookies(self):
|
||||
h = odict.ODictCaseless()
|
||||
h["Cookie"] = ["cookiename=cookievalue"]
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
result = r.get_cookies()
|
||||
result["cookiename"] = ["foo"]
|
||||
r.set_cookies(result)
|
||||
assert r.get_cookies()["cookiename"] == ["foo"]
|
||||
|
||||
def test_set_url(self):
|
||||
r = tutils.treq_absolute()
|
||||
r.url = "https://otheraddress:42/ORLY"
|
||||
assert r.scheme == "https"
|
||||
assert r.host == "otheraddress"
|
||||
assert r.port == 42
|
||||
assert r.path == "/ORLY"
|
||||
|
||||
try:
|
||||
r.url = "//localhost:80/foo@bar"
|
||||
assert False
|
||||
except:
|
||||
assert True
|
||||
|
||||
class TestRequest:
|
||||
# def test_asterisk_form_in(self):
|
||||
# f = tutils.tflow(req=None)
|
||||
# protocol = mock_protocol("OPTIONS * HTTP/1.1")
|
||||
@ -92,105 +349,35 @@ class TestRequest:
|
||||
# "Host: address\r\n"
|
||||
# "Content-Length: 0\r\n\r\n")
|
||||
|
||||
def test_set_url(self):
|
||||
r = tutils.treq_absolute()
|
||||
r.url = "https://otheraddress:42/ORLY"
|
||||
assert r.scheme == "https"
|
||||
assert r.host == "otheraddress"
|
||||
assert r.port == 42
|
||||
assert r.path == "/ORLY"
|
||||
|
||||
def test_repr(self):
|
||||
r = tutils.treq()
|
||||
assert repr(r)
|
||||
|
||||
def test_pretty_host(self):
|
||||
r = tutils.treq()
|
||||
assert r.pretty_host(True) == "address"
|
||||
assert r.pretty_host(False) == "address"
|
||||
r.headers["host"] = ["other"]
|
||||
assert r.pretty_host(True) == "other"
|
||||
assert r.pretty_host(False) == "address"
|
||||
r.host = None
|
||||
assert r.pretty_host(True) == "other"
|
||||
assert r.pretty_host(False) is None
|
||||
del r.headers["host"]
|
||||
assert r.pretty_host(True) is None
|
||||
assert r.pretty_host(False) is None
|
||||
|
||||
# Invalid IDNA
|
||||
r.headers["host"] = [".disqus.com"]
|
||||
assert r.pretty_host(True) == ".disqus.com"
|
||||
|
||||
def test_get_form_for_urlencoded(self):
|
||||
r = tutils.treq()
|
||||
r.headers.add("content-type", "application/x-www-form-urlencoded")
|
||||
r.get_form_urlencoded = MagicMock()
|
||||
|
||||
r.get_form()
|
||||
|
||||
assert r.get_form_urlencoded.called
|
||||
|
||||
def test_get_form_for_multipart(self):
|
||||
r = tutils.treq()
|
||||
r.headers.add("content-type", "multipart/form-data")
|
||||
r.get_form_multipart = MagicMock()
|
||||
|
||||
r.get_form()
|
||||
|
||||
assert r.get_form_multipart.called
|
||||
|
||||
def test_get_cookies_none(self):
|
||||
h = odict.ODictCaseless()
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
assert len(r.get_cookies()) == 0
|
||||
|
||||
def test_get_cookies_single(self):
|
||||
h = odict.ODictCaseless()
|
||||
h["Cookie"] = ["cookiename=cookievalue"]
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
result = r.get_cookies()
|
||||
assert len(result) == 1
|
||||
assert result['cookiename'] == ['cookievalue']
|
||||
|
||||
def test_get_cookies_double(self):
|
||||
h = odict.ODictCaseless()
|
||||
h["Cookie"] = [
|
||||
"cookiename=cookievalue;othercookiename=othercookievalue"
|
||||
]
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
result = r.get_cookies()
|
||||
assert len(result) == 2
|
||||
assert result['cookiename'] == ['cookievalue']
|
||||
assert result['othercookiename'] == ['othercookievalue']
|
||||
|
||||
def test_get_cookies_withequalsign(self):
|
||||
h = odict.ODictCaseless()
|
||||
h["Cookie"] = [
|
||||
"cookiename=coo=kievalue;othercookiename=othercookievalue"
|
||||
]
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
result = r.get_cookies()
|
||||
assert len(result) == 2
|
||||
assert result['cookiename'] == ['coo=kievalue']
|
||||
assert result['othercookiename'] == ['othercookievalue']
|
||||
|
||||
def test_set_cookies(self):
|
||||
h = odict.ODictCaseless()
|
||||
h["Cookie"] = ["cookiename=cookievalue"]
|
||||
r = tutils.treq()
|
||||
r.headers = h
|
||||
result = r.get_cookies()
|
||||
result["cookiename"] = ["foo"]
|
||||
r.set_cookies(result)
|
||||
assert r.get_cookies()["cookiename"] == ["foo"]
|
||||
|
||||
class TestEmptyRequest(object):
|
||||
def test_init(self):
|
||||
req = semantics.EmptyRequest()
|
||||
assert req
|
||||
|
||||
class TestResponse(object):
|
||||
def test_headers_odict(self):
|
||||
tutils.raises(AssertionError, semantics.Response,
|
||||
(1, 1),
|
||||
200,
|
||||
headers='foobar',
|
||||
)
|
||||
|
||||
resp = semantics.Response(
|
||||
(1, 1),
|
||||
200,
|
||||
)
|
||||
assert isinstance(resp.headers, odict.ODictCaseless)
|
||||
|
||||
def test_equal(self):
|
||||
a = tutils.tresp()
|
||||
b = tutils.tresp()
|
||||
assert a == b
|
||||
|
||||
assert not a == 'foo'
|
||||
assert not b == 'foo'
|
||||
assert not 'foo' == a
|
||||
assert not 'foo' == b
|
||||
|
||||
def test_repr(self):
|
||||
r = tutils.tresp()
|
||||
assert "unknown content type" in repr(r)
|
||||
|
@ -1,5 +1,6 @@
|
||||
from netlib import encoding
|
||||
|
||||
|
||||
def test_identity():
|
||||
assert "string" == encoding.decode("identity", "string")
|
||||
assert "string" == encoding.encode("identity", "string")
|
||||
|
@ -44,7 +44,11 @@ def test_client_greeting_assert_socks5():
|
||||
assert False
|
||||
|
||||
raw = tutils.treader("XX")
|
||||
tutils.raises(socks.SocksError, socks.ClientGreeting.from_file, raw, fail_early=True)
|
||||
tutils.raises(
|
||||
socks.SocksError,
|
||||
socks.ClientGreeting.from_file,
|
||||
raw,
|
||||
fail_early=True)
|
||||
|
||||
|
||||
def test_server_greeting():
|
||||
|
@ -1,5 +1,3 @@
|
||||
import urlparse
|
||||
|
||||
from netlib import utils, odict, tutils
|
||||
|
||||
|
||||
@ -30,8 +28,6 @@ def test_pretty_size():
|
||||
assert utils.pretty_size(1024 * 1024) == "1MB"
|
||||
|
||||
|
||||
|
||||
|
||||
def test_parse_url():
|
||||
assert not utils.parse_url("")
|
||||
|
||||
@ -86,7 +82,6 @@ def test_urlencode():
|
||||
assert utils.urlencode([('foo', 'bar')])
|
||||
|
||||
|
||||
|
||||
def test_urldecode():
|
||||
s = "one=two&three=four"
|
||||
assert len(utils.urldecode(s)) == 2
|
||||
@ -101,3 +96,31 @@ def test_get_header_tokens():
|
||||
assert utils.get_header_tokens(h, "foo") == ["bar", "voing"]
|
||||
h["foo"] = ["bar, voing", "oink"]
|
||||
assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"]
|
||||
|
||||
|
||||
def test_multipartdecode():
|
||||
boundary = 'somefancyboundary'
|
||||
headers = odict.ODict(
|
||||
[('content-type', ('multipart/form-data; boundary=%s' % boundary))])
|
||||
content = "--{0}\n" \
|
||||
"Content-Disposition: form-data; name=\"field1\"\n\n" \
|
||||
"value1\n" \
|
||||
"--{0}\n" \
|
||||
"Content-Disposition: form-data; name=\"field2\"\n\n" \
|
||||
"value2\n" \
|
||||
"--{0}--".format(boundary)
|
||||
|
||||
form = utils.multipartdecode(headers, content)
|
||||
|
||||
assert len(form) == 2
|
||||
assert form[0] == ('field1', 'value1')
|
||||
assert form[1] == ('field2', 'value2')
|
||||
|
||||
|
||||
def test_parse_content_type():
|
||||
p = utils.parse_content_type
|
||||
assert p("text/html") == ("text", "html", {})
|
||||
assert p("text") is None
|
||||
|
||||
v = p("text/html; charset=UTF-8")
|
||||
assert v == ('text', 'html', {'charset': 'UTF-8'})
|
||||
|
@ -3,7 +3,8 @@ import threading
|
||||
import Queue
|
||||
import cStringIO
|
||||
import OpenSSL
|
||||
from netlib import tcp, certutils, tutils
|
||||
from netlib import tcp
|
||||
from netlib import tutils
|
||||
|
||||
|
||||
class ServerThread(threading.Thread):
|
||||
|
@ -2,7 +2,10 @@ import os
|
||||
|
||||
from nose.tools import raises
|
||||
|
||||
from netlib import tcp, http, websockets, tutils
|
||||
from netlib import tcp
|
||||
from netlib import tutils
|
||||
from netlib import websockets
|
||||
from netlib.http import status_codes
|
||||
from netlib.http.exceptions import *
|
||||
from netlib.http.http1 import HTTP1Protocol
|
||||
from .. import tservers
|
||||
@ -38,7 +41,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
req = http1_protocol.read_request()
|
||||
key = self.protocol.check_client_handshake(req.headers)
|
||||
|
||||
preamble = http1_protocol.response_preamble(101)
|
||||
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
headers = self.protocol.server_handshake_headers(key)
|
||||
self.wfile.write(headers.format() + "\r\n")
|
||||
@ -62,7 +65,7 @@ class WebSocketsClient(tcp.TCPClient):
|
||||
|
||||
http1_protocol = HTTP1Protocol(self)
|
||||
|
||||
preamble = http1_protocol.request_preamble("GET", "/")
|
||||
preamble = 'GET / HTTP/1.1'
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
headers = self.protocol.client_handshake_headers()
|
||||
self.client_nonce = headers.get_first("sec-websocket-key")
|
||||
@ -162,7 +165,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler):
|
||||
client_hs = http1_protocol.read_request()
|
||||
self.protocol.check_client_handshake(client_hs.headers)
|
||||
|
||||
preamble = http1_protocol.response_preamble(101)
|
||||
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
headers = self.protocol.server_handshake_headers("malformed key")
|
||||
self.wfile.write(headers.format() + "\r\n")
|
||||
|
Loading…
Reference in New Issue
Block a user