refactor request model

This commit is contained in:
Maximilian Hils 2015-09-26 00:39:04 +02:00
parent 45f2ea33b2
commit 106f7046d3
17 changed files with 598 additions and 345 deletions

View File

@ -1,12 +1,15 @@
from __future__ import absolute_import, print_function, division
from .headers import Headers
from .models import Request, Response
from .message import decoded
from .request import Request
from .models import Response
from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2
from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING
from . import http1, http2
__all__ = [
"Headers",
"decoded",
"Request", "Response",
"ALPN_PROTO_HTTP1", "ALPN_PROTO_H2",
"HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING",

View File

@ -27,7 +27,7 @@ else:
_always_byte_args = always_byte_args("utf-8", "surrogateescape")
class Headers(MutableMapping, object):
class Headers(MutableMapping):
"""
Header class which allows both convenient access to individual headers as well as
direct access to the underlying raw data. Provides a full dictionary interface.

View File

@ -7,24 +7,24 @@ from .. import CONTENT_MISSING
def assemble_request(request):
if request.body == CONTENT_MISSING:
if request.content == CONTENT_MISSING:
raise HttpException("Cannot assemble flow with CONTENT_MISSING")
head = assemble_request_head(request)
body = b"".join(assemble_body(request.headers, [request.body]))
body = b"".join(assemble_body(request.headers, [request.data.content]))
return head + body
def assemble_request_head(request):
first_line = _assemble_request_line(request)
headers = _assemble_request_headers(request)
first_line = _assemble_request_line(request.data)
headers = _assemble_request_headers(request.data)
return b"%s\r\n%s\r\n" % (first_line, headers)
def assemble_response(response):
if response.body == CONTENT_MISSING:
if response.content == CONTENT_MISSING:
raise HttpException("Cannot assemble flow with CONTENT_MISSING")
head = assemble_response_head(response)
body = b"".join(assemble_body(response.headers, [response.body]))
body = b"".join(assemble_body(response.headers, [response.content]))
return head + body
@ -45,42 +45,49 @@ def assemble_body(headers, body_chunks):
yield chunk
def _assemble_request_line(request, form=None):
if form is None:
form = request.form_out
def _assemble_request_line(request_data):
"""
Args:
request_data (netlib.http.request.RequestData)
"""
form = request_data.first_line_format
if form == "relative":
return b"%s %s %s" % (
request.method,
request.path,
request.http_version
request_data.method,
request_data.path,
request_data.http_version
)
elif form == "authority":
return b"%s %s:%d %s" % (
request.method,
request.host,
request.port,
request.http_version
request_data.method,
request_data.host,
request_data.port,
request_data.http_version
)
elif form == "absolute":
return b"%s %s://%s:%d%s %s" % (
request.method,
request.scheme,
request.host,
request.port,
request.path,
request.http_version
request_data.method,
request_data.scheme,
request_data.host,
request_data.port,
request_data.path,
request_data.http_version
)
else: # pragma: nocover
else:
raise RuntimeError("Invalid request form")
def _assemble_request_headers(request):
headers = request.headers.copy()
if "host" not in headers and request.scheme and request.host and request.port:
def _assemble_request_headers(request_data):
"""
Args:
request_data (netlib.http.request.RequestData)
"""
headers = request_data.headers.copy()
if "host" not in headers and request_data.scheme and request_data.host and request_data.port:
headers["host"] = utils.hostport(
request.scheme,
request.host,
request.port
request_data.scheme,
request_data.host,
request_data.port
)
return bytes(headers)

View File

@ -11,7 +11,7 @@ from .. import Request, Response, Headers
def read_request(rfile, body_size_limit=None):
request = read_request_head(rfile)
expected_body_size = expected_http_body_size(request)
request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
request.timestamp_end = time.time()
return request
@ -155,7 +155,7 @@ def connection_close(http_version, headers):
# If we don't have a Connection header, HTTP 1.1 connections are assumed to
# be persistent
return http_version != b"HTTP/1.1"
return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1" # FIXME: Remove one case.
def expected_http_body_size(request, response=None):
@ -184,11 +184,11 @@ def expected_http_body_size(request, response=None):
if headers.get("expect", "").lower() == "100-continue":
return 0
else:
if request.method.upper() == b"HEAD":
if request.method.upper() == "HEAD":
return 0
if 100 <= response_code <= 199:
return 0
if response_code == 200 and request.method.upper() == b"CONNECT":
if response_code == 200 and request.method.upper() == "CONNECT":
return 0
if response_code in (204, 304):
return 0

146
netlib/http/message.py Normal file
View File

@ -0,0 +1,146 @@
from __future__ import absolute_import, print_function, division
import warnings
import six
from .. import encoding, utils
# Conversion helpers between the wire representation (bytes) and the
# interpreter's native string type.
if not six.PY2:
    # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
    def _native(x):
        return x.decode("utf-8", "surrogateescape")

    def _always_bytes(x):
        return utils.always_bytes(x, "utf-8", "surrogateescape")
else:
    # On Python 2 values are passed through unchanged in both directions.
    def _native(x):
        return x

    def _always_bytes(x):
        return x
class Message(object):
    """
    Attribute-style view over a message data object.

    All state lives on ``self.data``; the properties below read from and write
    to it. Two messages compare equal when their data objects compare equal.

    Args:
        data: the object holding the raw message attributes (``http_version``,
            ``headers``, ``content``, ``timestamp_start``, ``timestamp_end``).
    """

    def __init__(self, data):
        self.data = data

    def __eq__(self, other):
        if isinstance(other, Message):
            return self.data == other.data
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def http_version(self):
        """
        Version string, e.g. "HTTP/1.1"
        """
        return _native(self.data.http_version)

    @http_version.setter
    def http_version(self, http_version):
        self.data.http_version = _always_bytes(http_version)

    @property
    def headers(self):
        """
        Message headers object

        Returns:
            netlib.http.Headers
        """
        return self.data.headers

    @headers.setter
    def headers(self, h):
        self.data.headers = h

    @property
    def timestamp_start(self):
        """
        First byte timestamp
        """
        return self.data.timestamp_start

    @timestamp_start.setter
    def timestamp_start(self, timestamp_start):
        self.data.timestamp_start = timestamp_start

    @property
    def timestamp_end(self):
        """
        Last byte timestamp
        """
        return self.data.timestamp_end

    @timestamp_end.setter
    def timestamp_end(self, timestamp_end):
        self.data.timestamp_end = timestamp_end

    @property
    def content(self):
        """
        The raw (encoded) HTTP message body

        Assigning a ``bytes`` value also rewrites the ``content-length`` header
        to the new length; any other value leaves the header untouched.

        See also: :py:attr:`text`
        """
        return self.data.content

    @content.setter
    def content(self, content):
        self.data.content = content
        if isinstance(content, bytes):
            self.headers["content-length"] = str(len(content))

    @property
    def text(self):
        """
        The decoded HTTP message body.
        Decoded contents are not cached, so this method is relatively expensive to call.

        Not implemented yet: reading or assigning raises NotImplementedError.

        See also: :py:attr:`content`, :py:class:`decoded`
        """
        # This attribute should be called text, because that's what requests does.
        raise NotImplementedError()

    @text.setter
    def text(self, text):
        raise NotImplementedError()

    @property
    def body(self):
        """
        Deprecated alias for :py:attr:`content`; emits a DeprecationWarning.
        """
        # stacklevel=2 attributes the warning to the code that accessed .body,
        # which is the line the user actually has to change.
        warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning, stacklevel=2)
        return self.content

    @body.setter
    def body(self, body):
        warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning, stacklevel=2)
        self.content = body
class decoded(object):
    """
    Context manager that decodes a message for the duration of a block and
    re-encodes it with the same content-encoding afterwards.

    The encoding is taken from the message's ``content-encoding`` header and is
    only acted upon when it is listed in ``encoding.ENCODINGS``. If the
    message's ``decode()`` reports failure (returns a falsy value), nothing is
    re-encoded on exit.

    Example:

    .. code-block:: python

        with decoded(request):
            request.content = request.content.replace("foo", "bar")

    Args:
        message: object exposing ``headers``, ``decode()`` and ``encode(ce)``.
    """

    def __init__(self, message):
        self.message = message
        content_encoding = message.headers.get("content-encoding")
        supported = content_encoding in encoding.ENCODINGS
        # Remember the encoding only if we know how to handle it.
        self.ce = content_encoding if supported else None

    def __enter__(self):
        # A failed decode means there is nothing to undo on exit.
        if self.ce and not self.message.decode():
            self.ce = None

    def __exit__(self, type, value, tb):
        if not self.ce:
            return
        self.message.encode(self.ce)

View File

@ -47,239 +47,6 @@ class Message(object):
return False
class Request(Message):
def __init__(
self,
form_in,
method,
scheme,
host,
port,
path,
http_version,
headers=None,
body=None,
timestamp_start=None,
timestamp_end=None,
form_out=None
):
super(Request, self).__init__(http_version, headers, body, timestamp_start, timestamp_end)
self.form_in = form_in
self.method = method
self.scheme = scheme
self.host = host
self.port = port
self.path = path
self.form_out = form_out or form_in
def __repr__(self):
if self.host and self.port:
hostport = "{}:{}".format(native(self.host,"idna"), self.port)
else:
hostport = ""
path = self.path or ""
return "HTTPRequest({} {}{})".format(
self.method, hostport, path
)
def anticache(self):
"""
Modifies this request to remove headers that might produce a cached
response. That is, we remove ETags and If-Modified-Since headers.
"""
delheaders = [
"if-modified-since",
"if-none-match",
]
for i in delheaders:
self.headers.pop(i, None)
def anticomp(self):
"""
Modifies this request to remove headers that will compress the
resource's data.
"""
self.headers["accept-encoding"] = "identity"
def constrain_encoding(self):
"""
Limits the permissible Accept-Encoding values, based on what we can
decode appropriately.
"""
accept_encoding = self.headers.get("accept-encoding")
if accept_encoding:
self.headers["accept-encoding"] = (
', '.join(
e
for e in encoding.ENCODINGS
if e in accept_encoding
)
)
def update_host_header(self):
"""
Update the host header to reflect the current target.
"""
self.headers["host"] = self.host
def get_form(self):
"""
Retrieves the URL-encoded or multipart form data, returning an ODict object.
Returns an empty ODict if there is no data or the content-type
indicates non-form data.
"""
if self.body:
if HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower():
return self.get_form_urlencoded()
elif HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower():
return self.get_form_multipart()
return ODict([])
def get_form_urlencoded(self):
"""
Retrieves the URL-encoded form data, returning an ODict object.
Returns an empty ODict if there is no data or the content-type
indicates non-form data.
"""
if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower():
return ODict(utils.urldecode(self.body))
return ODict([])
def get_form_multipart(self):
if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower():
return ODict(
utils.multipartdecode(
self.headers,
self.body))
return ODict([])
def set_form_urlencoded(self, odict):
"""
Sets the body to the URL-encoded form data, and adds the
appropriate content-type header. Note that this will destory the
existing body if there is one.
"""
# FIXME: If there's an existing content-type header indicating a
# url-encoded form, leave it alone.
self.headers["content-type"] = HDR_FORM_URLENCODED
self.body = utils.urlencode(odict.lst)
def get_path_components(self):
"""
Returns the path components of the URL as a list of strings.
Components are unquoted.
"""
_, _, path, _, _, _ = urllib.parse.urlparse(self.url)
return [urllib.parse.unquote(native(i,"ascii")) for i in path.split(b"/") if i]
def set_path_components(self, lst):
"""
Takes a list of strings, and sets the path component of the URL.
Components are quoted.
"""
lst = [urllib.parse.quote(i, safe="") for i in lst]
path = always_bytes("/" + "/".join(lst))
scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
self.url = urllib.parse.urlunparse(
[scheme, netloc, path, params, query, fragment]
)
def get_query(self):
"""
Gets the request query string. Returns an ODict object.
"""
_, _, _, _, query, _ = urllib.parse.urlparse(self.url)
if query:
return ODict(utils.urldecode(query))
return ODict([])
def set_query(self, odict):
"""
Takes an ODict object, and sets the request query string.
"""
scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
query = utils.urlencode(odict.lst)
self.url = urllib.parse.urlunparse(
[scheme, netloc, path, params, query, fragment]
)
def pretty_host(self, hostheader):
"""
Heuristic to get the host of the request.
Note that pretty_host() does not always return the TCP destination
of the request, e.g. if an upstream proxy is in place
If hostheader is set to True, the Host: header will be used as
additional (and preferred) data source. This is handy in
transparent mode, where only the IO of the destination is known,
but not the resolved name. This is disabled by default, as an
attacker may spoof the host header to confuse an analyst.
"""
if hostheader and "host" in self.headers:
try:
return self.headers["host"]
except ValueError:
pass
if self.host:
return self.host.decode("idna")
def pretty_url(self, hostheader):
if self.form_out == "authority": # upstream proxy mode
return b"%s:%d" % (always_bytes(self.pretty_host(hostheader)), self.port)
return utils.unparse_url(self.scheme,
self.pretty_host(hostheader),
self.port,
self.path)
def get_cookies(self):
"""
Returns a possibly empty netlib.odict.ODict object.
"""
ret = ODict()
for i in self.headers.get_all("Cookie"):
ret.extend(cookies.parse_cookie_header(i))
return ret
def set_cookies(self, odict):
"""
Takes an netlib.odict.ODict object. Over-writes any existing Cookie
headers.
"""
v = cookies.format_cookie_header(odict)
self.headers["cookie"] = v
@property
def url(self):
"""
Returns a URL string, constructed from the Request's URL components.
"""
return utils.unparse_url(
self.scheme,
self.host,
self.port,
self.path
)
@url.setter
def url(self, url):
"""
Parses a URL specification, and updates the Request's information
accordingly.
Raises:
ValueError if the URL was invalid
"""
# TODO: Should handle incoming unicode here.
parts = utils.parse_url(url)
if not parts:
raise ValueError("Invalid URL: %s" % url)
self.scheme, self.host, self.port, self.path = parts
class Response(Message):
def __init__(
self,

351
netlib/http/request.py Normal file
View File

@ -0,0 +1,351 @@
from __future__ import absolute_import, print_function, division
import warnings
import six
from six.moves import urllib
from netlib import utils
from netlib.http import cookies
from netlib.odict import ODict
from .. import encoding
from .headers import Headers
from .message import Message, _native, _always_bytes
class RequestData(object):
    """
    Plain container for the raw attributes of an HTTP request.

    Values are stored exactly as passed; no conversion or validation is done
    here apart from the headers check described below. Two instances compare
    equal when all of their attributes compare equal.

    Args:
        first_line_format: request form, e.g. "relative", "authority" or "absolute".
        method, scheme, host, port, path, http_version: request line components.
        headers: a Headers instance. When omitted, a new empty Headers object
            is created.
        content: the message body.
        timestamp_start, timestamp_end: first/last byte timestamps.

    Raises:
        AssertionError: if ``headers`` is a non-empty value that is not a
            Headers instance.
    """

    def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None,
                 timestamp_start=None, timestamp_end=None):
        # Substitute a fresh Headers object only when the caller supplied
        # nothing usable. An *empty* Headers instance is falsy as well, but it
        # must be kept: the caller may still hold a reference to it and expect
        # later modifications to be visible on this request.
        if not isinstance(headers, Headers) and not headers:
            headers = Headers()
        assert isinstance(headers, Headers)

        self.first_line_format = first_line_format
        self.method = method
        self.scheme = scheme
        self.host = host
        self.port = port
        self.path = path
        self.http_version = http_version
        self.headers = headers
        self.content = content
        self.timestamp_start = timestamp_start
        self.timestamp_end = timestamp_end

    def __eq__(self, other):
        if isinstance(other, RequestData):
            return self.__dict__ == other.__dict__
        return False

    def __ne__(self, other):
        return not self.__eq__(other)
class Request(Message):
    """
    An HTTP request.

    All positional and keyword arguments are forwarded to RequestData; the
    resulting object is stored as ``self.data`` and exposed through the
    properties below.
    """

    def __init__(self, *args, **kwargs):
        data = RequestData(*args, **kwargs)
        super(Request, self).__init__(data)

    def __repr__(self):
        if self.host and self.port:
            hostport = "{}:{}".format(self.host, self.port)
        else:
            hostport = ""
        path = self.path or ""
        return "HTTPRequest({} {}{})".format(
            self.method, hostport, path
        )

    @property
    def first_line_format(self):
        """
        HTTP request form as defined in `RFC7230 <https://tools.ietf.org/html/rfc7230#section-5.3>`_.

        origin-form and asterisk-form are subsumed as "relative".
        """
        return self.data.first_line_format

    @first_line_format.setter
    def first_line_format(self, first_line_format):
        self.data.first_line_format = first_line_format

    @property
    def method(self):
        """
        HTTP request method, e.g. "GET".
        """
        return _native(self.data.method)

    @method.setter
    def method(self, method):
        self.data.method = _always_bytes(method)

    @property
    def scheme(self):
        """
        HTTP request scheme, which should be "http" or "https".
        None if no scheme is set.
        """
        # The scheme may be unset (other code in this class and the assembler
        # test it for truthiness); don't try to decode None.
        if self.data.scheme is None:
            return None
        return _native(self.data.scheme)

    @scheme.setter
    def scheme(self, scheme):
        self.data.scheme = _always_bytes(scheme)

    @property
    def host(self):
        """
        Target host for the request. This may be directly taken in the request (e.g. "GET http://example.com/ HTTP/1.1")
        or inferred from the proxy mode (e.g. an IP in transparent mode).

        Setting the host also updates an existing Host header (or removes it
        when the host is set to a falsy value); no header is added if none exists.
        """
        if six.PY2:
            return self.data.host

        if not self.data.host:
            return self.data.host
        try:
            return self.data.host.decode("idna")
        except UnicodeError:
            return self.data.host.decode("utf8", "surrogateescape")

    @host.setter
    def host(self, host):
        if isinstance(host, six.text_type):
            try:
                # There's no non-strict mode for IDNA encoding.
                # We don't want this operation to fail though, so we try
                # utf8 as a last resort.
                host = host.encode("idna", "strict")
            except UnicodeError:
                host = host.encode("utf8", "surrogateescape")

        self.data.host = host

        # Update host header
        if "host" in self.headers:
            if host:
                self.headers["host"] = host
            else:
                self.headers.pop("host")

    @property
    def port(self):
        """
        Target port
        """
        return self.data.port

    @port.setter
    def port(self, port):
        self.data.port = port

    @property
    def path(self):
        """
        HTTP request path, e.g. "/index.html".
        None if no path is set.
        """
        # __repr__ treats the path as optional, so None must not be decoded.
        if self.data.path is None:
            return None
        return _native(self.data.path)

    @path.setter
    def path(self, path):
        self.data.path = _always_bytes(path)

    def anticache(self):
        """
        Modifies this request to remove headers that might produce a cached
        response. That is, we remove the If-None-Match (ETag) and
        If-Modified-Since headers.
        """
        delheaders = [
            "if-modified-since",
            "if-none-match",
        ]
        for i in delheaders:
            self.headers.pop(i, None)

    def anticomp(self):
        """
        Modifies this request to remove headers that will compress the
        resource's data.
        """
        self.headers["accept-encoding"] = "identity"

    def constrain_encoding(self):
        """
        Limits the permissible Accept-Encoding values, based on what we can
        decode appropriately.
        """
        accept_encoding = self.headers.get("accept-encoding")
        if accept_encoding:
            self.headers["accept-encoding"] = (
                ', '.join(
                    e
                    for e in encoding.ENCODINGS
                    if e in accept_encoding
                )
            )

    @property
    def urlencoded_form(self):
        """
        The URL-encoded form data as an ODict object.
        None if there is no data or the content-type indicates non-form data.

        Assigning an ODict sets the content-type header to
        "application/x-www-form-urlencoded" and overwrites the existing content.
        """
        is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
        if self.content and is_valid_content_type:
            return ODict(utils.urldecode(self.content))
        return None

    @urlencoded_form.setter
    def urlencoded_form(self, odict):
        """
        Sets the body to the URL-encoded form data, and adds the appropriate content-type header.
        This will overwrite the existing content if there is one.
        """
        self.headers["content-type"] = "application/x-www-form-urlencoded"
        self.content = utils.urlencode(odict.lst)

    @property
    def multipart_form(self):
        """
        The multipart form data as an ODict object.
        None if there is no data or the content-type indicates non-form data.

        Assignment is not implemented and raises NotImplementedError.
        """
        is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
        if self.content and is_valid_content_type:
            return ODict(utils.multipartdecode(self.headers, self.content))
        return None

    @multipart_form.setter
    def multipart_form(self, value):
        # A property setter is always called with the assigned value; without
        # this parameter, assignment would fail with a TypeError before the
        # intended NotImplementedError could be raised.
        raise NotImplementedError()

    @property
    def path_components(self):
        """
        The URL's path components as a list of strings.
        Components are unquoted.

        Assigning a list of strings quotes each component and replaces the
        path of the URL.
        """
        _, _, path, _, _, _ = urllib.parse.urlparse(self.url)
        return [urllib.parse.unquote(i) for i in path.split("/") if i]

    @path_components.setter
    def path_components(self, components):
        components = [urllib.parse.quote(x, safe="") for x in components]
        path = "/" + "/".join(components)
        scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
        self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])

    @property
    def query(self):
        """
        The request query string as an ODict object.
        None, if there is no query.
        """
        _, _, _, _, query, _ = urllib.parse.urlparse(self.url)
        if query:
            return ODict(utils.urldecode(query))
        return None

    @query.setter
    def query(self, odict):
        query = utils.urlencode(odict.lst)
        scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
        self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])

    @property
    def cookies(self):
        """
        The request cookies.
        An empty ODict object if the cookie monster ate them all.

        Assigning an ODict overwrites the Cookie header.
        """
        ret = ODict()
        for i in self.headers.get_all("Cookie"):
            ret.extend(cookies.parse_cookie_header(i))
        return ret

    @cookies.setter
    def cookies(self, odict):
        self.headers["cookie"] = cookies.format_cookie_header(odict)

    @property
    def url(self):
        """
        The URL string, constructed from the request's URL components

        Assigning a URL parses it and updates scheme, host, port and path.

        Raises:
            ValueError: on assignment, if the URL could not be parsed.
        """
        return utils.unparse_url(self.scheme, self.host, self.port, self.path)

    @url.setter
    def url(self, url):
        parts = utils.parse_url(url)
        # NOTE(review): parse_url appears to signal an invalid URL with a falsy
        # return value rather than an exception - confirm against utils.parse_url.
        # Report that as a ValueError instead of an opaque unpacking TypeError.
        if not parts:
            raise ValueError("Invalid URL: %s" % url)
        self.scheme, self.host, self.port, self.path = parts

    @property
    def pretty_host(self):
        """
        The Host header value if the request has one, otherwise :py:attr:`host`.
        """
        return self.headers.get("host", self.host)

    @property
    def pretty_url(self):
        """
        Like :py:attr:`url`, but built from :py:attr:`pretty_host`.
        For authority-form requests this is just "host:port".
        """
        if self.first_line_format == "authority":
            return "%s:%d" % (self.pretty_host, self.port)
        return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path)

    # Legacy
    #
    # The accessors below are kept for backwards compatibility. Each one emits
    # a DeprecationWarning attributed to the caller (stacklevel=2) and then
    # delegates to the property that replaces it.

    def get_cookies(self):
        warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning, stacklevel=2)
        return self.cookies

    def set_cookies(self, odict):
        warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning, stacklevel=2)
        self.cookies = odict

    def get_query(self):
        warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning, stacklevel=2)
        # The legacy API returned an empty ODict rather than None.
        return self.query or ODict([])

    def set_query(self, odict):
        warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning, stacklevel=2)
        self.query = odict

    def get_path_components(self):
        warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning,
                      stacklevel=2)
        return self.path_components

    def set_path_components(self, lst):
        warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning,
                      stacklevel=2)
        self.path_components = lst

    def get_form_urlencoded(self):
        warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning,
                      stacklevel=2)
        return self.urlencoded_form or ODict([])

    def set_form_urlencoded(self, odict):
        warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning,
                      stacklevel=2)
        self.urlencoded_form = odict

    def get_form_multipart(self):
        warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning,
                      stacklevel=2)
        return self.multipart_form or ODict([])

    @property
    def form_in(self):
        warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning, stacklevel=2)
        return self.first_line_format

    @form_in.setter
    def form_in(self, form_in):
        warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning, stacklevel=2)
        self.first_line_format = form_in

    @property
    def form_out(self):
        warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning, stacklevel=2)
        return self.first_line_format

    @form_out.setter
    def form_out(self, form_out):
        warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning, stacklevel=2)
        self.first_line_format = form_out

3
netlib/http/response.py Normal file
View File

@ -0,0 +1,3 @@
from __future__ import absolute_import, print_function, division
# TODO

View File

@ -98,7 +98,7 @@ def treq(**kwargs):
netlib.http.Request
"""
default = dict(
form_in="relative",
first_line_format="relative",
method=b"GET",
scheme=b"http",
host=b"address",
@ -106,7 +106,7 @@ def treq(**kwargs):
path=b"/path",
http_version=b"HTTP/1.1",
headers=Headers(header="qvalue"),
body=b"content"
content=b"content"
)
default.update(kwargs)
return Request(**default)

View File

@ -273,22 +273,27 @@ def get_header_tokens(headers, key):
return [token.strip() for token in tokens]
@always_byte_args()
def hostport(scheme, host, port):
"""
Returns the host component, with a port specification if needed.
"""
if (port, scheme) in [(80, b"http"), (443, b"https")]:
if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]:
return host
else:
if isinstance(host, six.binary_type):
return b"%s:%d" % (host, port)
else:
return "%s:%d" % (host, port)
def unparse_url(scheme, host, port, path=""):
"""
Returns a URL string, constructed from the specified compnents.
Returns a URL string, constructed from the specified components.
Args:
All args must be str.
"""
return b"%s://%s%s" % (scheme, hostport(scheme, host, port), path)
return "%s://%s%s" % (scheme, hostport(scheme, host, port), path)
def urlencode(s):

View File

@ -20,7 +20,7 @@ def test_assemble_request():
)
with raises(HttpException):
assemble_request(treq(body=CONTENT_MISSING))
assemble_request(treq(content=CONTENT_MISSING))
def test_assemble_request_head():
@ -62,21 +62,21 @@ def test_assemble_body():
def test_assemble_request_line():
assert _assemble_request_line(treq()) == b"GET /path HTTP/1.1"
assert _assemble_request_line(treq().data) == b"GET /path HTTP/1.1"
authority_request = treq(method=b"CONNECT", form_in="authority")
authority_request = treq(method=b"CONNECT", first_line_format="authority").data
assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1"
absolute_request = treq(form_in="absolute")
absolute_request = treq(first_line_format="absolute").data
assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1"
with raises(RuntimeError):
_assemble_request_line(treq(), "invalid_form")
_assemble_request_line(treq(first_line_format="invalid_form").data)
def test_assemble_request_headers():
# https://github.com/mitmproxy/mitmproxy/issues/186
r = treq(body=b"")
r = treq(content=b"")
r.headers["Transfer-Encoding"] = "chunked"
c = _assemble_request_headers(r)
assert b"Transfer-Encoding" in c

View File

@ -16,8 +16,8 @@ from netlib.tutils import treq, tresp, raises
def test_read_request():
rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip")
r = read_request(rfile)
assert r.method == b"GET"
assert r.body == b""
assert r.method == "GET"
assert r.content == b""
assert r.timestamp_end
assert rfile.read() == b"skip"
@ -32,7 +32,7 @@ def test_read_request_head():
rfile.reset_timestamps = Mock()
rfile.first_byte_timestamp = 42
r = read_request_head(rfile)
assert r.method == b"GET"
assert r.method == "GET"
assert r.headers["Content-Length"] == "4"
assert r.body is None
assert rfile.reset_timestamps.called
@ -283,7 +283,7 @@ class TestReadHeaders(object):
def test_read_chunked():
req = treq(body=None)
req = treq(content=None)
req.headers["Transfer-Encoding"] = "chunked"
data = b"1\r\na\r\n0\r\n"

View File

@ -39,6 +39,7 @@ class TestRequest(object):
a = tutils.treq(timestamp_start=42, timestamp_end=43)
b = tutils.treq(timestamp_start=42, timestamp_end=43)
assert a == b
assert not a != b
assert not a == 'foo'
assert not b == 'foo'
@ -70,45 +71,17 @@ class TestRequest(object):
req = tutils.treq()
req.headers["Host"] = ""
req.host = "foobar"
req.update_host_header()
assert req.headers["Host"] == "foobar"
def test_get_form(self):
req = tutils.treq()
assert req.get_form() == ODict()
@mock.patch("netlib.http.Request.get_form_multipart")
@mock.patch("netlib.http.Request.get_form_urlencoded")
def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart):
req = tutils.treq()
assert req.get_form() == ODict()
req = tutils.treq()
req.body = "foobar"
req.headers["Content-Type"] = HDR_FORM_URLENCODED
req.get_form()
assert req.get_form_urlencoded.called
assert not req.get_form_multipart.called
@mock.patch("netlib.http.Request.get_form_multipart")
@mock.patch("netlib.http.Request.get_form_urlencoded")
def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart):
req = tutils.treq()
req.body = "foobar"
req.headers["Content-Type"] = HDR_FORM_MULTIPART
req.get_form()
assert not req.get_form_urlencoded.called
assert req.get_form_multipart.called
def test_get_form_urlencoded(self):
req = tutils.treq(body="foobar")
req = tutils.treq(content="foobar")
assert req.get_form_urlencoded() == ODict()
req.headers["Content-Type"] = HDR_FORM_URLENCODED
assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body))
def test_get_form_multipart(self):
req = tutils.treq(body="foobar")
req = tutils.treq(content="foobar")
assert req.get_form_multipart() == ODict()
req.headers["Content-Type"] = HDR_FORM_MULTIPART
@ -140,7 +113,7 @@ class TestRequest(object):
assert req.get_query().lst == []
req.url = "http://localhost:80/foo?bar=42"
assert req.get_query().lst == [(b"bar", b"42")]
assert req.get_query().lst == [("bar", "42")]
def test_set_query(self):
req = tutils.treq()
@ -148,31 +121,23 @@ class TestRequest(object):
def test_pretty_host(self):
r = tutils.treq()
assert r.pretty_host(True) == "address"
assert r.pretty_host(False) == "address"
assert r.pretty_host == "address"
assert r.host == "address"
r.headers["host"] = "other"
assert r.pretty_host(True) == "other"
assert r.pretty_host(False) == "address"
assert r.pretty_host == "other"
assert r.host == "address"
r.host = None
assert r.pretty_host(True) == "other"
assert r.pretty_host(False) is None
del r.headers["host"]
assert r.pretty_host(True) is None
assert r.pretty_host(False) is None
assert r.pretty_host is None
assert r.host is None
# Invalid IDNA
r.headers["host"] = ".disqus.com"
assert r.pretty_host(True) == ".disqus.com"
assert r.pretty_host == ".disqus.com"
def test_pretty_url(self):
req = tutils.treq()
req.form_out = "authority"
assert req.pretty_url(True) == b"address:22"
assert req.pretty_url(False) == b"address:22"
req.form_out = "relative"
assert req.pretty_url(True) == b"http://address:22/path"
assert req.pretty_url(False) == b"http://address:22/path"
req = tutils.treq(first_line_format="relative")
assert req.pretty_url == "http://address:22/path"
assert req.url == "http://address:22/path"
def test_get_cookies_none(self):
headers = Headers()
@ -212,12 +177,12 @@ class TestRequest(object):
assert r.get_cookies()["cookiename"] == ["foo"]
def test_set_url(self):
r = tutils.treq(form_in="absolute")
r = tutils.treq(first_line_format="absolute")
r.url = b"https://otheraddress:42/ORLY"
assert r.scheme == b"https"
assert r.host == b"otheraddress"
assert r.scheme == "https"
assert r.host == "otheraddress"
assert r.port == 42
assert r.path == b"/ORLY"
assert r.path == "/ORLY"
try:
r.url = "//localhost:80/foo@bar"
@ -230,7 +195,7 @@ class TestRequest(object):
# protocol = mock_protocol("OPTIONS * HTTP/1.1")
# f.request = HTTPRequest.from_protocol(protocol)
#
# assert f.request.form_in == "relative"
# assert f.request.first_line_format == "relative"
# f.request.host = f.server_conn.address.host
# f.request.port = f.server_conn.address.port
# f.request.scheme = "http"
@ -266,7 +231,7 @@ class TestRequest(object):
# "CONNECT address:22 HTTP/1.1\r\n"
# "Host: address:22\r\n"
# "Content-Length: 0\r\n\r\n")
# assert r.pretty_url(False) == "address:22"
# assert r.pretty_url == "address:22"
#
# def test_absolute_form_in(self):
# protocol = mock_protocol("GET oops-no-protocol.com HTTP/1.1")

View File

@ -0,0 +1,3 @@
from __future__ import absolute_import, print_function, division
# TODO

View File

@ -0,0 +1,3 @@
from __future__ import absolute_import, print_function, division
# TODO

View File

@ -84,10 +84,10 @@ def test_parse_url():
def test_unparse_url():
assert utils.unparse_url(b"http", b"foo.com", 99, b"") == b"http://foo.com:99"
assert utils.unparse_url(b"http", b"foo.com", 80, b"/bar") == b"http://foo.com/bar"
assert utils.unparse_url(b"https", b"foo.com", 80, b"") == b"https://foo.com:80"
assert utils.unparse_url(b"https", b"foo.com", 443, b"") == b"https://foo.com"
assert utils.unparse_url("http", "foo.com", 99, "") == "http://foo.com:99"
assert utils.unparse_url("http", "foo.com", 80, "/bar") == "http://foo.com/bar"
assert utils.unparse_url("https", "foo.com", 80, "") == "https://foo.com:80"
assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com"
def test_urlencode():

View File

@ -68,7 +68,7 @@ class WebSocketsClient(tcp.TCPClient):
self.wfile.write(bytes(headers) + b"\r\n")
self.wfile.flush()
resp = read_response(self.rfile, treq(method="GET"))
resp = read_response(self.rfile, treq(method=b"GET"))
server_nonce = self.protocol.check_server_handshake(resp.headers)
if not server_nonce == self.protocol.create_server_nonce(self.client_nonce):