Merge branch 'mhils-multidict'

This commit is contained in:
Aldo Cortesi 2016-05-21 15:01:19 +12:00
commit 97f3077082
31 changed files with 884 additions and 504 deletions

View File

@ -56,6 +56,17 @@ Datastructures
:special-members: :special-members:
:no-undoc-members: :no-undoc-members:
.. autoclass:: MultiDictView
.. automethod:: get_all
.. automethod:: set_all
.. automethod:: add
.. automethod:: insert
.. automethod:: keys
.. automethod:: values
.. automethod:: items
.. automethod:: to_dict
.. autoclass:: decoded .. autoclass:: decoded
.. automodule:: mitmproxy.models .. automodule:: mitmproxy.models

View File

@ -1,5 +1,8 @@
def request(context, flow): def request(context, flow):
form = flow.request.urlencoded_form if flow.request.urlencoded_form:
if form is not None: flow.request.urlencoded_form["mitmproxy"] = "rocks"
form["mitmproxy"] = ["rocks"] else:
flow.request.urlencoded_form = form # This sets the proper content type and overrides the body.
flow.request.urlencoded_form = [
("foo", "bar")
]

View File

@ -1,5 +1,2 @@
def request(context, flow): def request(context, flow):
q = flow.request.query flow.request.query["mitmproxy"] = "rocks"
if q:
q["mitmproxy"] = ["rocks"]
flow.request.query = q

View File

@ -6,8 +6,7 @@ import sys
import math import math
import urwid import urwid
from netlib import odict from netlib.http import Headers, status_codes
from netlib.http import Headers
from . import common, grideditor, signals, searchable, tabs from . import common, grideditor, signals, searchable, tabs
from . import flowdetailview from . import flowdetailview
from .. import utils, controller, contentviews from .. import utils, controller, contentviews
@ -316,21 +315,18 @@ class FlowView(tabs.Tabs):
return "Invalid URL." return "Invalid URL."
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)
def set_resp_code(self, code): def set_resp_status_code(self, status_code):
response = self.flow.response
try: try:
response.status_code = int(code) status_code = int(status_code)
except ValueError: except ValueError:
return None return None
import BaseHTTPServer self.flow.response.status_code = status_code
if int(code) in BaseHTTPServer.BaseHTTPRequestHandler.responses: if status_code in status_codes.RESPONSES:
response.msg = BaseHTTPServer.BaseHTTPRequestHandler.responses[ self.flow.response.reason = status_codes.RESPONSES[status_code]
int(code)][0]
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)
def set_resp_msg(self, msg): def set_resp_reason(self, reason):
response = self.flow.response self.flow.response.reason = reason
response.msg = msg
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)
def set_headers(self, fields, conn): def set_headers(self, fields, conn):
@ -338,22 +334,22 @@ class FlowView(tabs.Tabs):
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)
def set_query(self, lst, conn): def set_query(self, lst, conn):
conn.set_query(odict.ODict(lst)) conn.query = lst
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)
def set_path_components(self, lst, conn): def set_path_components(self, lst, conn):
conn.set_path_components(lst) conn.path_components = lst
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)
def set_form(self, lst, conn): def set_form(self, lst, conn):
conn.set_form_urlencoded(odict.ODict(lst)) conn.urlencoded_form = lst
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)
def edit_form(self, conn): def edit_form(self, conn):
self.master.view_grideditor( self.master.view_grideditor(
grideditor.URLEncodedFormEditor( grideditor.URLEncodedFormEditor(
self.master, self.master,
conn.get_form_urlencoded().lst, conn.urlencoded_form.items(multi=True),
self.set_form, self.set_form,
conn conn
) )
@ -364,7 +360,7 @@ class FlowView(tabs.Tabs):
self.edit_form(conn) self.edit_form(conn)
def set_cookies(self, lst, conn): def set_cookies(self, lst, conn):
conn.cookies = odict.ODict(lst) conn.cookies = lst
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)
def set_setcookies(self, data, conn): def set_setcookies(self, data, conn):
@ -388,7 +384,7 @@ class FlowView(tabs.Tabs):
self.master.view_grideditor( self.master.view_grideditor(
grideditor.CookieEditor( grideditor.CookieEditor(
self.master, self.master,
message.cookies.lst, message.cookies.items(multi=True),
self.set_cookies, self.set_cookies,
message message
) )
@ -397,7 +393,7 @@ class FlowView(tabs.Tabs):
self.master.view_grideditor( self.master.view_grideditor(
grideditor.SetCookieEditor( grideditor.SetCookieEditor(
self.master, self.master,
message.cookies, message.cookies.items(multi=True),
self.set_setcookies, self.set_setcookies,
message message
) )
@ -413,7 +409,7 @@ class FlowView(tabs.Tabs):
c = self.master.spawn_editor(message.content or "") c = self.master.spawn_editor(message.content or "")
message.content = c.rstrip("\n") message.content = c.rstrip("\n")
elif part == "f": elif part == "f":
if not message.get_form_urlencoded() and message.content: if not message.urlencoded_form and message.content:
signals.status_prompt_onekey.send( signals.status_prompt_onekey.send(
prompt = "Existing body is not a URL-encoded form. Clear and edit?", prompt = "Existing body is not a URL-encoded form. Clear and edit?",
keys = [ keys = [
@ -435,7 +431,7 @@ class FlowView(tabs.Tabs):
) )
) )
elif part == "p": elif part == "p":
p = message.get_path_components() p = message.path_components
self.master.view_grideditor( self.master.view_grideditor(
grideditor.PathEditor( grideditor.PathEditor(
self.master, self.master,
@ -448,7 +444,7 @@ class FlowView(tabs.Tabs):
self.master.view_grideditor( self.master.view_grideditor(
grideditor.QueryEditor( grideditor.QueryEditor(
self.master, self.master,
message.get_query().lst, message.query.items(multi=True),
self.set_query, message self.set_query, message
) )
) )
@ -458,7 +454,7 @@ class FlowView(tabs.Tabs):
text = message.url, text = message.url,
callback = self.set_url callback = self.set_url
) )
elif part == "m": elif part == "m" and message == self.flow.request:
signals.status_prompt_onekey.send( signals.status_prompt_onekey.send(
prompt = "Method", prompt = "Method",
keys = common.METHOD_OPTIONS, keys = common.METHOD_OPTIONS,
@ -468,13 +464,13 @@ class FlowView(tabs.Tabs):
signals.status_prompt.send( signals.status_prompt.send(
prompt = "Code", prompt = "Code",
text = str(message.status_code), text = str(message.status_code),
callback = self.set_resp_code callback = self.set_resp_status_code
) )
elif part == "m": elif part == "m" and message == self.flow.response:
signals.status_prompt.send( signals.status_prompt.send(
prompt = "Message", prompt = "Message",
text = message.msg, text = message.reason,
callback = self.set_resp_msg callback = self.set_resp_reason
) )
signals.flow_change.send(self, flow = self.flow) signals.flow_change.send(self, flow = self.flow)

View File

@ -700,17 +700,17 @@ class SetCookieEditor(GridEditor):
def data_in(self, data): def data_in(self, data):
flattened = [] flattened = []
for k, v in data.items(): for key, (value, attrs) in data:
flattened.append([k, v[0], v[1].lst]) flattened.append([key, value, attrs.items(multi=True)])
return flattened return flattened
def data_out(self, data): def data_out(self, data):
vals = [] vals = []
for i in data: for key, value, attrs in data:
vals.append( vals.append(
[ [
i[0], key,
[i[1], odict.ODictCaseless(i[2])] (value, attrs)
] ]
) )
return odict.ODict(vals) return vals

View File

@ -158,9 +158,9 @@ class SetHeaders:
for _, header, value, cpatt in self.lst: for _, header, value, cpatt in self.lst:
if cpatt(f): if cpatt(f):
if f.response: if f.response:
f.response.headers.fields.append((header, value)) f.response.headers.add(header, value)
else: else:
f.request.headers.fields.append((header, value)) f.request.headers.add(header, value)
class StreamLargeBodies(object): class StreamLargeBodies(object):
@ -265,7 +265,7 @@ class ServerPlaybackState:
form_contents = r.urlencoded_form or r.multipart_form form_contents = r.urlencoded_form or r.multipart_form
if self.ignore_payload_params and form_contents: if self.ignore_payload_params and form_contents:
key.extend( key.extend(
p for p in form_contents p for p in form_contents.items(multi=True)
if p[0] not in self.ignore_payload_params if p[0] not in self.ignore_payload_params
) )
else: else:
@ -321,10 +321,10 @@ class StickyCookieState:
""" """
domain = f.request.host domain = f.request.host
path = "/" path = "/"
if attrs["domain"]: if "domain" in attrs:
domain = attrs["domain"][-1] domain = attrs["domain"]
if attrs["path"]: if "path" in attrs:
path = attrs["path"][-1] path = attrs["path"]
return (domain, f.request.port, path) return (domain, f.request.port, path)
def domain_match(self, a, b): def domain_match(self, a, b):
@ -335,28 +335,26 @@ class StickyCookieState:
return False return False
def handle_response(self, f): def handle_response(self, f):
for i in f.response.headers.get_all("set-cookie"): for name, (value, attrs) in f.response.cookies.items(multi=True):
# FIXME: We now know that Cookie.py screws up some cookies with # FIXME: We now know that Cookie.py screws up some cookies with
# valid RFC 822/1123 datetime specifications for expiry. Sigh. # valid RFC 822/1123 datetime specifications for expiry. Sigh.
name, value, attrs = cookies.parse_set_cookie_header(str(i))
a = self.ckey(attrs, f) a = self.ckey(attrs, f)
if self.domain_match(f.request.host, a[0]): if self.domain_match(f.request.host, a[0]):
b = attrs.lst b = attrs.with_insert(0, name, value)
b.insert(0, [name, value]) self.jar[a][name] = b
self.jar[a][name] = odict.ODictCaseless(b)
def handle_request(self, f): def handle_request(self, f):
l = [] l = []
if f.match(self.flt): if f.match(self.flt):
for i in self.jar.keys(): for domain, port, path in self.jar.keys():
match = [ match = [
self.domain_match(f.request.host, i[0]), self.domain_match(f.request.host, domain),
f.request.port == i[1], f.request.port == port,
f.request.path.startswith(i[2]) f.request.path.startswith(path)
] ]
if all(match): if all(match):
c = self.jar[i] c = self.jar[(domain, port, path)]
l.extend([cookies.format_cookie_header(c[name]) for name in c.keys()]) l.extend([cookies.format_cookie_header(c[name].items(multi=True)) for name in c.keys()])
if l: if l:
f.request.stickycookie = True f.request.stickycookie = True
f.request.headers["cookie"] = "; ".join(l) f.request.headers["cookie"] = "; ".join(l)

View File

@ -51,7 +51,7 @@ def python_code(flow):
params = "" params = ""
if flow.request.query: if flow.request.query:
lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query] lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()]
params = "\nparams = {\n%s}\n" % "".join(lines) params = "\nparams = {\n%s}\n" % "".join(lines)
args += "\n params=params," args += "\n params=params,"
@ -140,7 +140,7 @@ def locust_code(flow):
params = "" params = ""
if flow.request.query: if flow.request.query:
lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query] lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()]
params = "\n params = {\n%s }\n" % "".join(lines) params = "\n params = {\n%s }\n" % "".join(lines)
args += "\n params=params," args += "\n params=params,"

View File

@ -5,7 +5,6 @@ from __future__ import absolute_import
from io import BytesIO from io import BytesIO
import gzip import gzip
import zlib import zlib
from .utils import always_byte_args
ENCODINGS = {"identity", "gzip", "deflate"} ENCODINGS = {"identity", "gzip", "deflate"}

View File

@ -3,12 +3,12 @@ from .request import Request
from .response import Response from .response import Response
from .headers import Headers from .headers import Headers
from .message import decoded from .message import decoded
from . import http1, http2 from . import http1, http2, status_codes
__all__ = [ __all__ = [
"Request", "Request",
"Response", "Response",
"Headers", "Headers",
"decoded", "decoded",
"http1", "http2", "http1", "http2", "status_codes",
] ]

View File

@ -1,8 +1,8 @@
from six.moves import http_cookies as Cookie import collections
import re import re
import string
from email.utils import parsedate_tz, formatdate, mktime_tz from email.utils import parsedate_tz, formatdate, mktime_tz
from netlib.multidict import ImmutableMultiDict
from .. import odict from .. import odict
""" """
@ -157,42 +157,76 @@ def _parse_set_cookie_pairs(s):
return pairs return pairs
def parse_set_cookie_headers(headers):
ret = []
for header in headers:
v = parse_set_cookie_header(header)
if v:
name, value, attrs = v
ret.append((name, SetCookie(value, attrs)))
return ret
class CookieAttrs(ImmutableMultiDict):
@staticmethod
def _kconv(key):
return key.lower()
@staticmethod
def _reduce_values(values):
# See the StickyCookieTest for a weird cookie that only makes sense
# if we take the last part.
return values[-1]
SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"])
def parse_set_cookie_header(line): def parse_set_cookie_header(line):
""" """
Parse a Set-Cookie header value Parse a Set-Cookie header value
Returns a (name, value, attrs) tuple, or None, where attrs is an Returns a (name, value, attrs) tuple, or None, where attrs is an
ODictCaseless set of attributes. No attempt is made to parse attribute CookieAttrs dict of attributes. No attempt is made to parse attribute
values - they are treated purely as strings. values - they are treated purely as strings.
""" """
pairs = _parse_set_cookie_pairs(line) pairs = _parse_set_cookie_pairs(line)
if pairs: if pairs:
return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) return pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:])
def format_set_cookie_header(name, value, attrs): def format_set_cookie_header(name, value, attrs):
""" """
Formats a Set-Cookie header value. Formats a Set-Cookie header value.
""" """
pairs = [[name, value]] pairs = [(name, value)]
pairs.extend(attrs.lst) pairs.extend(
attrs.fields if hasattr(attrs, "fields") else attrs
)
return _format_set_cookie_pairs(pairs) return _format_set_cookie_pairs(pairs)
def parse_cookie_headers(cookie_headers):
cookie_list = []
for header in cookie_headers:
cookie_list.extend(parse_cookie_header(header))
return cookie_list
def parse_cookie_header(line): def parse_cookie_header(line):
""" """
Parse a Cookie header value. Parse a Cookie header value.
Returns a (possibly empty) ODict object. Returns a list of (lhs, rhs) tuples.
""" """
pairs, off_ = _read_pairs(line) pairs, off_ = _read_pairs(line)
return odict.ODict(pairs) return pairs
def format_cookie_header(od): def format_cookie_header(lst):
""" """
Formats a Cookie header value. Formats a Cookie header value.
""" """
return _format_pairs(od.lst) return _format_pairs(lst)
def refresh_set_cookie_header(c, delta): def refresh_set_cookie_header(c, delta):
@ -209,10 +243,10 @@ def refresh_set_cookie_header(c, delta):
raise ValueError("Invalid Cookie") raise ValueError("Invalid Cookie")
if "expires" in attrs: if "expires" in attrs:
e = parsedate_tz(attrs["expires"][-1]) e = parsedate_tz(attrs["expires"])
if e: if e:
f = mktime_tz(e) + delta f = mktime_tz(e) + delta
attrs["expires"] = [formatdate(f)] attrs = attrs.with_set_all("expires", [formatdate(f)])
else: else:
# This can happen when the expires tag is invalid. # This can happen when the expires tag is invalid.
# reddit.com sends a an expires tag like this: "Thu, 31 Dec # reddit.com sends a an expires tag like this: "Thu, 31 Dec
@ -220,7 +254,7 @@ def refresh_set_cookie_header(c, delta):
# strictly correct according to the cookie spec. Browsers # strictly correct according to the cookie spec. Browsers
# appear to parse this tolerantly - maybe we should too. # appear to parse this tolerantly - maybe we should too.
# For now, we just ignore this. # For now, we just ignore this.
del attrs["expires"] attrs = attrs.with_delitem("expires")
ret = format_set_cookie_header(name, value, attrs) ret = format_set_cookie_header(name, value, attrs)
if not ret: if not ret:

View File

@ -1,9 +1,3 @@
"""
Unicode Handling
----------------
See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
"""
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import re import re
@ -13,23 +7,22 @@ try:
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
from collections import MutableMapping # Workaround for Python < 3.3 from collections import MutableMapping # Workaround for Python < 3.3
import six import six
from ..multidict import MultiDict
from ..utils import always_bytes
from netlib.utils import always_byte_args, always_bytes, Serializable # See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
if six.PY2: # pragma: no cover if six.PY2: # pragma: no cover
_native = lambda x: x _native = lambda x: x
_always_bytes = lambda x: x _always_bytes = lambda x: x
_always_byte_args = lambda x: x
else: else:
# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
_native = lambda x: x.decode("utf-8", "surrogateescape") _native = lambda x: x.decode("utf-8", "surrogateescape")
_always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape")
_always_byte_args = always_byte_args("utf-8", "surrogateescape")
class Headers(MutableMapping, Serializable): class Headers(MultiDict):
""" """
Header class which allows both convenient access to individual headers as well as Header class which allows both convenient access to individual headers as well as
direct access to the underlying raw data. Provides a full dictionary interface. direct access to the underlying raw data. Provides a full dictionary interface.
@ -49,11 +42,11 @@ class Headers(MutableMapping, Serializable):
>>> h["host"] >>> h["host"]
"example.com" "example.com"
# Headers can also be creatd from a list of raw (header_name, header_value) byte tuples # Headers can also be created from a list of raw (header_name, header_value) byte tuples
>>> h = Headers([ >>> h = Headers([
[b"Host",b"example.com"], (b"Host",b"example.com"),
[b"Accept",b"text/html"], (b"Accept",b"text/html"),
[b"accept",b"application/xml"] (b"accept",b"application/xml")
]) ])
# Multiple headers are folded into a single header as per RFC7230 # Multiple headers are folded into a single header as per RFC7230
@ -77,7 +70,6 @@ class Headers(MutableMapping, Serializable):
For use with the "Set-Cookie" header, see :py:meth:`get_all`. For use with the "Set-Cookie" header, see :py:meth:`get_all`.
""" """
@_always_byte_args
def __init__(self, fields=None, **headers): def __init__(self, fields=None, **headers):
""" """
Args: Args:
@ -89,19 +81,29 @@ class Headers(MutableMapping, Serializable):
If ``**headers`` contains multiple keys that have equal ``.lower()`` s, If ``**headers`` contains multiple keys that have equal ``.lower()`` s,
the behavior is undefined. the behavior is undefined.
""" """
self.fields = fields or [] super(Headers, self).__init__(fields)
for name, value in self.fields: for key, value in self.fields:
if not isinstance(name, bytes) or not isinstance(value, bytes): if not isinstance(key, bytes) or not isinstance(value, bytes):
raise ValueError("Headers passed as fields must be bytes.") raise TypeError("Header fields must be bytes.")
# content_type -> content-type # content_type -> content-type
headers = { headers = {
_always_bytes(name).replace(b"_", b"-"): value _always_bytes(name).replace(b"_", b"-"): _always_bytes(value)
for name, value in six.iteritems(headers) for name, value in six.iteritems(headers)
} }
self.update(headers) self.update(headers)
@staticmethod
def _reduce_values(values):
# Headers can be folded
return ", ".join(values)
@staticmethod
def _kconv(key):
# Headers are case-insensitive
return key.lower()
def __bytes__(self): def __bytes__(self):
if self.fields: if self.fields:
return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n"
@ -111,98 +113,40 @@ class Headers(MutableMapping, Serializable):
if six.PY2: # pragma: no cover if six.PY2: # pragma: no cover
__str__ = __bytes__ __str__ = __bytes__
@_always_byte_args def __delitem__(self, key):
def __getitem__(self, name): key = _always_bytes(key)
values = self.get_all(name) super(Headers, self).__delitem__(key)
if not values:
raise KeyError(name)
return ", ".join(values)
@_always_byte_args
def __setitem__(self, name, value):
idx = self._index(name)
# To please the human eye, we insert at the same position the first existing header occured.
if idx is not None:
del self[name]
self.fields.insert(idx, [name, value])
else:
self.fields.append([name, value])
@_always_byte_args
def __delitem__(self, name):
if name not in self:
raise KeyError(name)
name = name.lower()
self.fields = [
field for field in self.fields
if name != field[0].lower()
]
def __iter__(self): def __iter__(self):
seen = set() for x in super(Headers, self).__iter__():
for name, _ in self.fields: yield _native(x)
name_lower = name.lower()
if name_lower not in seen:
seen.add(name_lower)
yield _native(name)
def __len__(self):
return len(set(name.lower() for name, _ in self.fields))
# __hash__ = object.__hash__
def _index(self, name):
name = name.lower()
for i, field in enumerate(self.fields):
if field[0].lower() == name:
return i
return None
def __eq__(self, other):
if isinstance(other, Headers):
return self.fields == other.fields
return False
def __ne__(self, other):
return not self.__eq__(other)
@_always_byte_args
def get_all(self, name): def get_all(self, name):
""" """
Like :py:meth:`get`, but does not fold multiple headers into a single one. Like :py:meth:`get`, but does not fold multiple headers into a single one.
This is useful for Set-Cookie headers, which do not support folding. This is useful for Set-Cookie headers, which do not support folding.
See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 See also: https://tools.ietf.org/html/rfc7230#section-3.2.2
""" """
name_lower = name.lower() name = _always_bytes(name)
values = [_native(value) for n, value in self.fields if n.lower() == name_lower] return [
return values _native(x) for x in
super(Headers, self).get_all(name)
]
@_always_byte_args
def set_all(self, name, values): def set_all(self, name, values):
""" """
Explicitly set multiple headers for the given key. Explicitly set multiple headers for the given key.
See: :py:meth:`get_all` See: :py:meth:`get_all`
""" """
values = map(_always_bytes, values) # _always_byte_args does not fix lists name = _always_bytes(name)
if name in self: values = [_always_bytes(x) for x in values]
del self[name] return super(Headers, self).set_all(name, values)
self.fields.extend(
[name, value] for value in values
)
def get_state(self): def insert(self, index, key, value):
return tuple(tuple(field) for field in self.fields) key = _always_bytes(key)
value = _always_bytes(value)
super(Headers, self).insert(index, key, value)
def set_state(self, state):
self.fields = [list(field) for field in state]
@classmethod
def from_state(cls, state):
return cls([list(field) for field in state])
@_always_byte_args
def replace(self, pattern, repl, flags=0): def replace(self, pattern, repl, flags=0):
""" """
Replaces a regular expression pattern with repl in each "name: value" Replaces a regular expression pattern with repl in each "name: value"
@ -211,6 +155,8 @@ class Headers(MutableMapping, Serializable):
Returns: Returns:
The number of replacements made. The number of replacements made.
""" """
pattern = _always_bytes(pattern)
repl = _always_bytes(repl)
pattern = re.compile(pattern, flags) pattern = re.compile(pattern, flags)
replacements = 0 replacements = 0

View File

@ -316,14 +316,14 @@ def _read_headers(rfile):
if not ret: if not ret:
raise HttpSyntaxException("Invalid headers") raise HttpSyntaxException("Invalid headers")
# continued header # continued header
ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())
else: else:
try: try:
name, value = line.split(b":", 1) name, value = line.split(b":", 1)
value = value.strip() value = value.strip()
if not name: if not name:
raise ValueError() raise ValueError()
ret.append([name, value]) ret.append((name, value))
except ValueError: except ValueError:
raise HttpSyntaxException("Invalid headers") raise HttpSyntaxException("Invalid headers")
return Headers(ret) return Headers(ret)

View File

@ -201,13 +201,13 @@ class HTTP2Protocol(object):
headers = request.headers.copy() headers = request.headers.copy()
if ':authority' not in headers: if ':authority' not in headers:
headers.fields.insert(0, (b':authority', authority.encode('ascii'))) headers.insert(0, b':authority', authority.encode('ascii'))
if ':scheme' not in headers: if ':scheme' not in headers:
headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) headers.insert(0, b':scheme', request.scheme.encode('ascii'))
if ':path' not in headers: if ':path' not in headers:
headers.fields.insert(0, (b':path', request.path.encode('ascii'))) headers.insert(0, b':path', request.path.encode('ascii'))
if ':method' not in headers: if ':method' not in headers:
headers.fields.insert(0, (b':method', request.method.encode('ascii'))) headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'): if hasattr(request, 'stream_id'):
stream_id = request.stream_id stream_id = request.stream_id
@ -224,7 +224,7 @@ class HTTP2Protocol(object):
headers = response.headers.copy() headers = response.headers.copy()
if ':status' not in headers: if ':status' not in headers:
headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) headers.insert(0, b':status', str(response.status_code).encode('ascii'))
if hasattr(response, 'stream_id'): if hasattr(response, 'stream_id'):
stream_id = response.stream_id stream_id = response.stream_id
@ -420,7 +420,7 @@ class HTTP2Protocol(object):
self._handle_unexpected_frame(frm) self._handle_unexpected_frame(frm)
headers = Headers( headers = Headers(
[[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
) )
return stream_id, headers, body return stream_id, headers, body

View File

@ -4,6 +4,7 @@ import warnings
import six import six
from ..multidict import MultiDict
from .headers import Headers from .headers import Headers
from .. import encoding, utils from .. import encoding, utils

View File

@ -10,6 +10,7 @@ from netlib import utils
from netlib.http import cookies from netlib.http import cookies
from netlib.odict import ODict from netlib.odict import ODict
from .. import encoding from .. import encoding
from ..multidict import MultiDictView
from .headers import Headers from .headers import Headers
from .message import Message, _native, _always_bytes, MessageData from .message import Message, _native, _always_bytes, MessageData
@ -224,45 +225,64 @@ class Request(Message):
@property @property
def query(self): def query(self):
# type: () -> MultiDictView
""" """
The request query string as an :py:class:`ODict` object. The request query string as an :py:class:`MultiDictView` object.
None, if there is no query.
""" """
_, _, _, _, query, _ = urllib.parse.urlparse(self.url) return MultiDictView(
if query: self._get_query,
return ODict(utils.urldecode(query)) self._set_query
return None )
@query.setter def _get_query(self):
def query(self, odict): _, _, _, _, query, _ = urllib.parse.urlparse(self.url)
query = utils.urlencode(odict.lst) return tuple(utils.urldecode(query))
def _set_query(self, value):
query = utils.urlencode(value)
scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
_, _, _, self.path = utils.parse_url( _, _, _, self.path = utils.parse_url(
urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
@query.setter
def query(self, value):
self._set_query(value)
@property @property
def cookies(self): def cookies(self):
# type: () -> MultiDictView
""" """
The request cookies. The request cookies.
An empty :py:class:`ODict` object if the cookie monster ate them all.
An empty :py:class:`MultiDictView` object if the cookie monster ate them all.
""" """
ret = ODict() return MultiDictView(
for i in self.headers.get_all("Cookie"): self._get_cookies,
ret.extend(cookies.parse_cookie_header(i)) self._set_cookies
return ret )
def _get_cookies(self):
h = self.headers.get_all("Cookie")
return tuple(cookies.parse_cookie_headers(h))
def _set_cookies(self, value):
self.headers["cookie"] = cookies.format_cookie_header(value)
@cookies.setter @cookies.setter
def cookies(self, odict): def cookies(self, value):
self.headers["cookie"] = cookies.format_cookie_header(odict) self._set_cookies(value)
@property @property
def path_components(self): def path_components(self):
""" """
The URL's path components as a list of strings. The URL's path components as a tuple of strings.
Components are unquoted. Components are unquoted.
""" """
_, _, path, _, _, _ = urllib.parse.urlparse(self.url) _, _, path, _, _, _ = urllib.parse.urlparse(self.url)
return [urllib.parse.unquote(i) for i in path.split("/") if i] # This needs to be a tuple so that it's immutable.
# Otherwise, this would fail silently:
# request.path_components.append("foo")
return tuple(urllib.parse.unquote(i) for i in path.split("/") if i)
@path_components.setter @path_components.setter
def path_components(self, components): def path_components(self, components):
@ -309,64 +329,53 @@ class Request(Message):
@property @property
def urlencoded_form(self): def urlencoded_form(self):
""" """
The URL-encoded form data as an :py:class:`ODict` object. The URL-encoded form data as an :py:class:`MultiDictView` object.
None if there is no data or the content-type indicates non-form data. An empty MultiDictView if the content-type indicates non-form data
or the content could not be parsed.
""" """
is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() return MultiDictView(
if self.content and is_valid_content_type: self._get_urlencoded_form,
return ODict(utils.urldecode(self.content)) self._set_urlencoded_form
return None )
@urlencoded_form.setter def _get_urlencoded_form(self):
def urlencoded_form(self, odict): is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
if is_valid_content_type:
return tuple(utils.urldecode(self.content))
return ()
def _set_urlencoded_form(self, value):
""" """
Sets the body to the URL-encoded form data, and adds the appropriate content-type header. Sets the body to the URL-encoded form data, and adds the appropriate content-type header.
This will overwrite the existing content if there is one. This will overwrite the existing content if there is one.
""" """
self.headers["content-type"] = "application/x-www-form-urlencoded" self.headers["content-type"] = "application/x-www-form-urlencoded"
self.content = utils.urlencode(odict.lst) self.content = utils.urlencode(value)
@urlencoded_form.setter
def urlencoded_form(self, value):
self._set_urlencoded_form(value)
@property @property
def multipart_form(self): def multipart_form(self):
""" """
The multipart form data as an :py:class:`ODict` object. The multipart form data as an :py:class:`MultipartFormDict` object.
None if there is no data or the content-type indicates non-form data. None if the content-type indicates non-form data.
""" """
return MultiDictView(
self._get_multipart_form,
self._set_multipart_form
)
def _get_multipart_form(self):
is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
if self.content and is_valid_content_type: if is_valid_content_type:
return ODict(utils.multipartdecode(self.headers,self.content)) return utils.multipartdecode(self.headers, self.content)
return None return ()
def _set_multipart_form(self, value):
raise NotImplementedError()
@multipart_form.setter @multipart_form.setter
def multipart_form(self, value): def multipart_form(self, value):
raise NotImplementedError() self._set_multipart_form(value)
# Legacy
def get_query(self): # pragma: no cover
warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning)
return self.query or ODict([])
def set_query(self, odict): # pragma: no cover
warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning)
self.query = odict
def get_path_components(self): # pragma: no cover
warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning)
return self.path_components
def set_path_components(self, lst): # pragma: no cover
warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning)
self.path_components = lst
def get_form_urlencoded(self): # pragma: no cover
warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning)
return self.urlencoded_form or ODict([])
def set_form_urlencoded(self, odict): # pragma: no cover
warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning)
self.urlencoded_form = odict
def get_form_multipart(self): # pragma: no cover
warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning)
return self.multipart_form or ODict([])

View File

@ -1,14 +1,13 @@
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import warnings
from email.utils import parsedate_tz, formatdate, mktime_tz from email.utils import parsedate_tz, formatdate, mktime_tz
import time import time
from . import cookies from . import cookies
from .headers import Headers from .headers import Headers
from .message import Message, _native, _always_bytes, MessageData from .message import Message, _native, _always_bytes, MessageData
from ..multidict import MultiDictView
from .. import utils from .. import utils
from ..odict import ODict
class ResponseData(MessageData): class ResponseData(MessageData):
@ -72,29 +71,35 @@ class Response(Message):
@property @property
def cookies(self): def cookies(self):
# type: () -> MultiDictView
""" """
Get the contents of all Set-Cookie headers. The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are
cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is
an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly)
are indicated by a Null value.
A possibly empty :py:class:`ODict`, where keys are cookie name strings, Caveats:
and values are [value, attr] lists. Value is a string, and attr is Updating the attr
an ODictCaseless containing cookie attributes. Within attrs, unary
attributes (e.g. HTTPOnly) are indicated by a Null value.
""" """
ret = [] return MultiDictView(
for header in self.headers.get_all("set-cookie"): self._get_cookies,
v = cookies.parse_set_cookie_header(header) self._set_cookies
if v: )
name, value, attrs = v
ret.append([name, [value, attrs]]) def _get_cookies(self):
return ODict(ret) h = self.headers.get_all("set-cookie")
return tuple(cookies.parse_set_cookie_headers(h))
def _set_cookies(self, value):
cookie_headers = []
for k, v in value:
header = cookies.format_set_cookie_header(k, v[0], v[1])
cookie_headers.append(header)
self.headers.set_all("set-cookie", cookie_headers)
@cookies.setter @cookies.setter
def cookies(self, odict): def cookies(self, value):
values = [] self._set_cookies(value)
for i in odict.lst:
header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1])
values.append(header)
self.headers.set_all("set-cookie", values)
def refresh(self, now=None): def refresh(self, now=None):
""" """

279
netlib/multidict.py Normal file
View File

@ -0,0 +1,279 @@
from __future__ import absolute_import, print_function, division
from abc import ABCMeta, abstractmethod
from typing import Tuple, TypeVar
try:
from collections.abc import MutableMapping
except ImportError: # pragma: no cover
from collections import MutableMapping # Workaround for Python < 3.3
import six
from .utils import Serializable
@six.add_metaclass(ABCMeta)
class _MultiDict(MutableMapping, Serializable):
def __repr__(self):
fields = tuple(
repr(field)
for field in self.fields
)
return "{cls}[{fields}]".format(
cls=type(self).__name__,
fields=", ".join(fields)
)
@staticmethod
@abstractmethod
def _reduce_values(values):
"""
If a user accesses multidict["foo"], this method
reduces all values for "foo" to a single value that is returned.
For example, HTTP headers are folded, whereas we will just take
the first cookie we found with that name.
"""
@staticmethod
@abstractmethod
def _kconv(key):
"""
This method converts a key to its canonical representation.
For example, HTTP headers are case-insensitive, so this method returns key.lower().
"""
def __getitem__(self, key):
values = self.get_all(key)
if not values:
raise KeyError(key)
return self._reduce_values(values)
def __setitem__(self, key, value):
self.set_all(key, [value])
def __delitem__(self, key):
if key not in self:
raise KeyError(key)
key = self._kconv(key)
self.fields = tuple(
field for field in self.fields
if key != self._kconv(field[0])
)
def __iter__(self):
seen = set()
for key, _ in self.fields:
key_kconv = self._kconv(key)
if key_kconv not in seen:
seen.add(key_kconv)
yield key
def __len__(self):
return len(set(self._kconv(key) for key, _ in self.fields))
def __eq__(self, other):
if isinstance(other, MultiDict):
return self.fields == other.fields
return False
def __ne__(self, other):
return not self.__eq__(other)
def get_all(self, key):
"""
Return the list of all values for a given key.
If that key is not in the MultiDict, the return value will be an empty list.
"""
key = self._kconv(key)
return [
value
for k, value in self.fields
if self._kconv(k) == key
]
def set_all(self, key, values):
"""
Remove the old values for a key and add new ones.
"""
key_kconv = self._kconv(key)
new_fields = []
for field in self.fields:
if self._kconv(field[0]) == key_kconv:
if values:
new_fields.append(
(key, values.pop(0))
)
else:
new_fields.append(field)
while values:
new_fields.append(
(key, values.pop(0))
)
self.fields = tuple(new_fields)
def add(self, key, value):
"""
Add an additional value for the given key at the bottom.
"""
self.insert(len(self.fields), key, value)
def insert(self, index, key, value):
"""
Insert an additional value for the given key at the specified position.
"""
item = (key, value)
self.fields = self.fields[:index] + (item,) + self.fields[index:]
def keys(self, multi=False):
"""
Get all keys.
Args:
multi(bool):
If True, one key per value will be returned.
If False, duplicate keys will only be returned once.
"""
return (
k
for k, _ in self.items(multi)
)
def values(self, multi=False):
"""
Get all values.
Args:
multi(bool):
If True, all values will be returned.
If False, only the first value per key will be returned.
"""
return (
v
for _, v in self.items(multi)
)
def items(self, multi=False):
"""
Get all (key, value) tuples.
Args:
multi(bool):
If True, all (key, value) pairs will be returned
If False, only the first (key, value) pair per unique key will be returned.
"""
if multi:
return self.fields
else:
return super(_MultiDict, self).items()
def to_dict(self):
"""
Get the MultiDict as a plain Python dict.
Keys with multiple values are returned as lists.
Example:
.. code-block:: python
# Simple dict with duplicate values.
>>> d
MultiDictView[("name", "value"), ("a", "false"), ("a", "42")]
>>> d.to_dict()
{
"name": "value",
"a": ["false", "42"]
}
"""
d = {}
for key in self:
values = self.get_all(key)
if len(values) == 1:
d[key] = values[0]
else:
d[key] = values
return d
def get_state(self):
return self.fields
def set_state(self, state):
self.fields = tuple(tuple(x) for x in state)
@classmethod
def from_state(cls, state):
return cls(tuple(x) for x in state)
class MultiDict(_MultiDict):
def __init__(self, fields=None):
super(MultiDict, self).__init__()
self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...]
@six.add_metaclass(ABCMeta)
class ImmutableMultiDict(MultiDict):
def _immutable(self, *_):
raise TypeError('{} objects are immutable'.format(self.__class__.__name__))
__delitem__ = set_all = insert = _immutable
def with_delitem(self, key):
"""
Returns:
An updated ImmutableMultiDict. The original object will not be modified.
"""
ret = self.copy()
super(ImmutableMultiDict, ret).__delitem__(key)
return ret
def with_set_all(self, key, values):
"""
Returns:
An updated ImmutableMultiDict. The original object will not be modified.
"""
ret = self.copy()
super(ImmutableMultiDict, ret).set_all(key, values)
return ret
def with_insert(self, index, key, value):
"""
Returns:
An updated ImmutableMultiDict. The original object will not be modified.
"""
ret = self.copy()
super(ImmutableMultiDict, ret).insert(index, key, value)
return ret
class MultiDictView(_MultiDict):
"""
The MultiDictView provides the MultiDict interface over calculated data.
The view itself contains no state - data is retrieved from the parent on
request, and stored back to the parent on change.
"""
def __init__(self, getter, setter):
self._getter = getter
self._setter = setter
super(MultiDictView, self).__init__()
@staticmethod
def _kconv(key):
# All request-attributes are case-sensitive.
return key
@staticmethod
def _reduce_values(values):
# We just return the first element if
# multiple elements exist with the same key.
return values[0]
@property
def fields(self):
return self._getter()
@fields.setter
def fields(self, value):
return self._setter(value)

View File

@ -51,17 +51,6 @@ def always_bytes(unicode_or_bytes, *encode_args):
return unicode_or_bytes return unicode_or_bytes
def always_byte_args(*encode_args):
"""Decorator that transparently encodes all arguments passed as unicode"""
def decorator(fun):
def _fun(*args, **kwargs):
args = [always_bytes(arg, *encode_args) for arg in args]
kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)}
return fun(*args, **kwargs)
return _fun
return decorator
def native(s, *encoding_opts): def native(s, *encoding_opts):
""" """
Convert :py:class:`bytes` or :py:class:`unicode` to the native Convert :py:class:`bytes` or :py:class:`unicode` to the native

View File

@ -95,14 +95,22 @@ def test_modify_form():
flow = tutils.tflow(req=netutils.treq(headers=form_header)) flow = tutils.tflow(req=netutils.treq(headers=form_header))
with example("modify_form.py") as ex: with example("modify_form.py") as ex:
ex.run("request", flow) ex.run("request", flow)
assert flow.request.urlencoded_form["mitmproxy"] == ["rocks"] assert flow.request.urlencoded_form["mitmproxy"] == "rocks"
flow.request.headers["content-type"] = ""
ex.run("request", flow)
assert list(flow.request.urlencoded_form.items()) == [("foo", "bar")]
def test_modify_querystring(): def test_modify_querystring():
flow = tutils.tflow(req=netutils.treq(path="/search?q=term")) flow = tutils.tflow(req=netutils.treq(path="/search?q=term"))
with example("modify_querystring.py") as ex: with example("modify_querystring.py") as ex:
ex.run("request", flow) ex.run("request", flow)
assert flow.request.query["mitmproxy"] == ["rocks"] assert flow.request.query["mitmproxy"] == "rocks"
flow.request.path = "/"
ex.run("request", flow)
assert flow.request.query["mitmproxy"] == "rocks"
def test_modify_response_body(): def test_modify_response_body():

View File

@ -1067,60 +1067,6 @@ class TestRequest:
assert r.url == "https://address:22/path" assert r.url == "https://address:22/path"
assert r.pretty_url == "https://foo.com:22/path" assert r.pretty_url == "https://foo.com:22/path"
def test_path_components(self):
r = HTTPRequest.wrap(netlib.tutils.treq())
r.path = "/"
assert r.get_path_components() == []
r.path = "/foo/bar"
assert r.get_path_components() == ["foo", "bar"]
q = odict.ODict()
q["test"] = ["123"]
r.set_query(q)
assert r.get_path_components() == ["foo", "bar"]
r.set_path_components([])
assert r.get_path_components() == []
r.set_path_components(["foo"])
assert r.get_path_components() == ["foo"]
r.set_path_components(["/oo"])
assert r.get_path_components() == ["/oo"]
assert "%2F" in r.path
def test_getset_form_urlencoded(self):
d = odict.ODict([("one", "two"), ("three", "four")])
r = HTTPRequest.wrap(netlib.tutils.treq(content=netlib.utils.urlencode(d.lst)))
r.headers["content-type"] = "application/x-www-form-urlencoded"
assert r.get_form_urlencoded() == d
d = odict.ODict([("x", "y")])
r.set_form_urlencoded(d)
assert r.get_form_urlencoded() == d
r.headers["content-type"] = "foo"
assert not r.get_form_urlencoded()
def test_getset_query(self):
r = HTTPRequest.wrap(netlib.tutils.treq())
r.path = "/foo?x=y&a=b"
q = r.get_query()
assert q.lst == [("x", "y"), ("a", "b")]
r.path = "/"
q = r.get_query()
assert not q
r.path = "/?adsfa"
q = r.get_query()
assert q.lst == [("adsfa", "")]
r.path = "/foo?x=y&a=b"
assert r.get_query()
r.set_query(odict.ODict([]))
assert not r.get_query()
qv = odict.ODict([("a", "b"), ("c", "d")])
r.set_query(qv)
assert r.get_query() == qv
def test_anticache(self): def test_anticache(self):
r = HTTPRequest.wrap(netlib.tutils.treq()) r = HTTPRequest.wrap(netlib.tutils.treq())
r.headers = Headers() r.headers = Headers()

View File

@ -21,7 +21,7 @@ def python_equals(testdata, text):
assert clean_blanks(text).rstrip() == clean_blanks(d).rstrip() assert clean_blanks(text).rstrip() == clean_blanks(d).rstrip()
req_get = lambda: netlib.tutils.treq(method='GET', content='') req_get = lambda: netlib.tutils.treq(method='GET', content='', path=b"/path?a=foo&a=bar&b=baz")
req_post = lambda: netlib.tutils.treq(method='POST', headers=None) req_post = lambda: netlib.tutils.treq(method='POST', headers=None)
@ -31,7 +31,7 @@ req_patch = lambda: netlib.tutils.treq(method='PATCH', path=b"/path?query=param"
class TestExportCurlCommand(): class TestExportCurlCommand():
def test_get(self): def test_get(self):
flow = tutils.tflow(req=req_get()) flow = tutils.tflow(req=req_get())
result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path'""" result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path?a=foo&a=bar&b=baz'"""
assert flow_export.curl_command(flow) == result assert flow_export.curl_command(flow) == result
def test_post(self): def test_post(self):
@ -70,7 +70,7 @@ class TestRawRequest():
def test_get(self): def test_get(self):
flow = tutils.tflow(req=req_get()) flow = tutils.tflow(req=req_get())
result = dedent(""" result = dedent("""
GET /path HTTP/1.1\r GET /path?a=foo&a=bar&b=baz HTTP/1.1\r
header: qvalue\r header: qvalue\r
content-length: 7\r content-length: 7\r
host: address:22\r host: address:22\r

View File

@ -14,10 +14,16 @@ class UserBehavior(TaskSet):
'content-length': '7', 'content-length': '7',
} }
params = {
'a': ['foo', 'bar'],
'b': 'baz',
}
self.response = self.client.request( self.response = self.client.request(
method='GET', method='GET',
url=url, url=url,
headers=headers, headers=headers,
params=params,
) )
### Additional tasks can go here ### ### Additional tasks can go here ###

View File

@ -7,8 +7,14 @@
'content-length': '7', 'content-length': '7',
} }
params = {
'a': ['foo', 'bar'],
'b': 'baz',
}
self.response = self.client.request( self.response = self.client.request(
method='GET', method='GET',
url=url, url=url,
headers=headers, headers=headers,
params=params,
) )

View File

@ -7,10 +7,16 @@ headers = {
'content-length': '7', 'content-length': '7',
} }
params = {
'a': ['foo', 'bar'],
'b': 'baz',
}
response = requests.request( response = requests.request(
method='GET', method='GET',
url=url, url=url,
headers=headers, headers=headers,
params=params,
) )
print(response.text) print(response.text)

View File

@ -261,7 +261,7 @@ class TestReadHeaders(object):
b"\r\n" b"\r\n"
) )
headers = self._read(data) headers = self._read(data)
assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]] assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two"))
def test_read_multi(self): def test_read_multi(self):
data = ( data = (
@ -270,7 +270,7 @@ class TestReadHeaders(object):
b"\r\n" b"\r\n"
) )
headers = self._read(data) headers = self._read(data)
assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]] assert headers.fields == ((b"Header", b"one"), (b"Header", b"two"))
def test_read_continued(self): def test_read_continued(self):
data = ( data = (
@ -280,7 +280,7 @@ class TestReadHeaders(object):
b"\r\n" b"\r\n"
) )
headers = self._read(data) headers = self._read(data)
assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]] assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three"))
def test_read_continued_err(self): def test_read_continued_err(self):
data = b"\tfoo: bar\r\n" data = b"\tfoo: bar\r\n"
@ -300,7 +300,7 @@ class TestReadHeaders(object):
def test_read_empty_value(self): def test_read_empty_value(self):
data = b"bar:" data = b"bar:"
headers = self._read(data) headers = self._read(data)
assert headers.fields == [[b"bar", b""]] assert headers.fields == ((b"bar", b""),)
def test_read_chunked(): def test_read_chunked():
req = treq(content=None) req = treq(content=None)

View File

@ -312,7 +312,7 @@ class TestReadRequest(tservers.ServerTestBase):
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
assert req.stream_id assert req.stream_id
assert req.headers.fields == [[b':method', b'GET'], [b':path', b'/'], [b':scheme', b'https']] assert req.headers.fields == ((b':method', b'GET'), (b':path', b'/'), (b':scheme', b'https'))
assert req.content == b'foobar' assert req.content == b'foobar'
@ -418,7 +418,7 @@ class TestReadResponse(tservers.ServerTestBase):
assert resp.http_version == "HTTP/2.0" assert resp.http_version == "HTTP/2.0"
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.reason == '' assert resp.reason == ''
assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
assert resp.content == b'foobar' assert resp.content == b'foobar'
assert resp.timestamp_end assert resp.timestamp_end
@ -445,7 +445,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
assert resp.http_version == "HTTP/2.0" assert resp.http_version == "HTTP/2.0"
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.reason == '' assert resp.reason == ''
assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
assert resp.content == b'' assert resp.content == b''

View File

@ -128,10 +128,10 @@ def test_cookie_roundtrips():
] ]
for s, lst in pairs: for s, lst in pairs:
ret = cookies.parse_cookie_header(s) ret = cookies.parse_cookie_header(s)
assert ret.lst == lst assert ret == lst
s2 = cookies.format_cookie_header(ret) s2 = cookies.format_cookie_header(ret)
ret = cookies.parse_cookie_header(s2) ret = cookies.parse_cookie_header(s2)
assert ret.lst == lst assert ret == lst
def test_parse_set_cookie_pairs(): def test_parse_set_cookie_pairs():
@ -197,24 +197,28 @@ def test_parse_set_cookie_header():
], ],
[ [
"one=uno", "one=uno",
("one", "uno", []) ("one", "uno", ())
], ],
[ [
"one=uno; foo=bar", "one=uno; foo=bar",
("one", "uno", [["foo", "bar"]]) ("one", "uno", (("foo", "bar"),))
] ],
[
"one=uno; foo=bar; foo=baz",
("one", "uno", (("foo", "bar"), ("foo", "baz")))
],
] ]
for s, expected in vals: for s, expected in vals:
ret = cookies.parse_set_cookie_header(s) ret = cookies.parse_set_cookie_header(s)
if expected: if expected:
assert ret[0] == expected[0] assert ret[0] == expected[0]
assert ret[1] == expected[1] assert ret[1] == expected[1]
assert ret[2].lst == expected[2] assert ret[2].items(multi=True) == expected[2]
s2 = cookies.format_set_cookie_header(*ret) s2 = cookies.format_set_cookie_header(*ret)
ret2 = cookies.parse_set_cookie_header(s2) ret2 = cookies.parse_set_cookie_header(s2)
assert ret2[0] == expected[0] assert ret2[0] == expected[0]
assert ret2[1] == expected[1] assert ret2[1] == expected[1]
assert ret2[2].lst == expected[2] assert ret2[2].items(multi=True) == expected[2]
else: else:
assert ret is None assert ret is None

View File

@ -5,10 +5,10 @@ from netlib.tutils import raises
class TestHeaders(object): class TestHeaders(object):
def _2host(self): def _2host(self):
return Headers( return Headers(
[ (
[b"Host", b"example.com"], (b"Host", b"example.com"),
[b"host", b"example.org"] (b"host", b"example.org")
] )
) )
def test_init(self): def test_init(self):
@ -38,20 +38,10 @@ class TestHeaders(object):
assert headers["Host"] == "example.com" assert headers["Host"] == "example.com"
assert headers["Accept"] == "text/plain" assert headers["Accept"] == "text/plain"
with raises(ValueError): with raises(TypeError):
Headers([[b"Host", u"not-bytes"]]) Headers([[b"Host", u"not-bytes"]])
def test_getitem(self): def test_bytes(self):
headers = Headers(Host="example.com")
assert headers["Host"] == "example.com"
assert headers["host"] == "example.com"
with raises(KeyError):
_ = headers["Accept"]
headers = self._2host()
assert headers["Host"] == "example.com, example.org"
def test_str(self):
headers = Headers(Host="example.com") headers = Headers(Host="example.com")
assert bytes(headers) == b"Host: example.com\r\n" assert bytes(headers) == b"Host: example.com\r\n"
@ -64,93 +54,6 @@ class TestHeaders(object):
headers = Headers() headers = Headers()
assert bytes(headers) == b"" assert bytes(headers) == b""
def test_setitem(self):
headers = Headers()
headers["Host"] = "example.com"
assert "Host" in headers
assert "host" in headers
assert headers["Host"] == "example.com"
headers["host"] = "example.org"
assert "Host" in headers
assert "host" in headers
assert headers["Host"] == "example.org"
headers["accept"] = "text/plain"
assert len(headers) == 2
assert "Accept" in headers
assert "Host" in headers
headers = self._2host()
assert len(headers.fields) == 2
headers["Host"] = "example.com"
assert len(headers.fields) == 1
assert "Host" in headers
def test_delitem(self):
headers = Headers(Host="example.com")
assert len(headers) == 1
del headers["host"]
assert len(headers) == 0
try:
del headers["host"]
except KeyError:
assert True
else:
assert False
headers = self._2host()
del headers["Host"]
assert len(headers) == 0
def test_keys(self):
headers = Headers(Host="example.com")
assert list(headers.keys()) == ["Host"]
headers = self._2host()
assert list(headers.keys()) == ["Host"]
def test_eq_ne(self):
headers1 = Headers(Host="example.com")
headers2 = Headers(host="example.com")
assert not (headers1 == headers2)
assert headers1 != headers2
headers1 = Headers(Host="example.com")
headers2 = Headers(Host="example.com")
assert headers1 == headers2
assert not (headers1 != headers2)
assert headers1 != 42
def test_get_all(self):
headers = self._2host()
assert headers.get_all("host") == ["example.com", "example.org"]
assert headers.get_all("accept") == []
def test_set_all(self):
headers = Headers(Host="example.com")
headers.set_all("Accept", ["text/plain"])
assert len(headers) == 2
assert "accept" in headers
headers = self._2host()
headers.set_all("Host", ["example.org"])
assert headers["host"] == "example.org"
headers.set_all("Host", ["example.org", "example.net"])
assert headers["host"] == "example.org, example.net"
def test_state(self):
headers = self._2host()
assert len(headers.get_state()) == 2
assert headers == Headers.from_state(headers.get_state())
headers2 = Headers()
assert headers != headers2
headers2.set_state(headers.get_state())
assert headers == headers2
def test_replace_simple(self): def test_replace_simple(self):
headers = Headers(Host="example.com", Accept="text/plain") headers = Headers(Host="example.com", Accept="text/plain")
replacements = headers.replace("Host: ", "X-Host: ") replacements = headers.replace("Host: ", "X-Host: ")

View File

@ -3,16 +3,14 @@ from __future__ import absolute_import, print_function, division
import six import six
from netlib import utils
from netlib.http import Headers from netlib.http import Headers
from netlib.odict import ODict
from netlib.tutils import treq, raises from netlib.tutils import treq, raises
from .test_message import _test_decoded_attr, _test_passthrough_attr from .test_message import _test_decoded_attr, _test_passthrough_attr
class TestRequestData(object): class TestRequestData(object):
def test_init(self): def test_init(self):
with raises(ValueError if six.PY2 else TypeError): with raises(ValueError):
treq(headers="foobar") treq(headers="foobar")
assert isinstance(treq(headers=None).headers, Headers) assert isinstance(treq(headers=None).headers, Headers)
@ -158,16 +156,17 @@ class TestRequestUtils(object):
def test_get_query(self): def test_get_query(self):
request = treq() request = treq()
assert request.query is None assert not request.query
request.url = "http://localhost:80/foo?bar=42" request.url = "http://localhost:80/foo?bar=42"
assert request.query.lst == [("bar", "42")] assert dict(request.query) == {"bar": "42"}
def test_set_query(self): def test_set_query(self):
request = treq(host=b"foo", headers = Headers(host=b"bar")) request = treq()
request.query = ODict([]) assert not request.query
assert request.host == "foo" request.query["foo"] = "bar"
assert request.headers["host"] == "bar" assert request.query["foo"] == "bar"
assert request.path == "/path?foo=bar"
def test_get_cookies_none(self): def test_get_cookies_none(self):
request = treq() request = treq()
@ -177,47 +176,50 @@ class TestRequestUtils(object):
def test_get_cookies_single(self): def test_get_cookies_single(self):
request = treq() request = treq()
request.headers = Headers(cookie="cookiename=cookievalue") request.headers = Headers(cookie="cookiename=cookievalue")
result = request.cookies assert len(request.cookies) == 1
assert len(result) == 1 assert request.cookies['cookiename'] == 'cookievalue'
assert result['cookiename'] == ['cookievalue']
def test_get_cookies_double(self): def test_get_cookies_double(self):
request = treq() request = treq()
request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue")
result = request.cookies result = request.cookies
assert len(result) == 2 assert len(result) == 2
assert result['cookiename'] == ['cookievalue'] assert result['cookiename'] == 'cookievalue'
assert result['othercookiename'] == ['othercookievalue'] assert result['othercookiename'] == 'othercookievalue'
def test_get_cookies_withequalsign(self): def test_get_cookies_withequalsign(self):
request = treq() request = treq()
request.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") request.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue")
result = request.cookies result = request.cookies
assert len(result) == 2 assert len(result) == 2
assert result['cookiename'] == ['coo=kievalue'] assert result['cookiename'] == 'coo=kievalue'
assert result['othercookiename'] == ['othercookievalue'] assert result['othercookiename'] == 'othercookievalue'
def test_set_cookies(self): def test_set_cookies(self):
request = treq() request = treq()
request.headers = Headers(cookie="cookiename=cookievalue") request.headers = Headers(cookie="cookiename=cookievalue")
result = request.cookies result = request.cookies
result["cookiename"] = ["foo"] result["cookiename"] = "foo"
request.cookies = result assert request.cookies["cookiename"] == "foo"
assert request.cookies["cookiename"] == ["foo"]
def test_get_path_components(self): def test_get_path_components(self):
request = treq(path=b"/foo/bar") request = treq(path=b"/foo/bar")
assert request.path_components == ["foo", "bar"] assert request.path_components == ("foo", "bar")
def test_set_path_components(self): def test_set_path_components(self):
request = treq(host=b"foo", headers = Headers(host=b"bar")) request = treq()
request.path_components = ["foo", "baz"] request.path_components = ["foo", "baz"]
assert request.path == "/foo/baz" assert request.path == "/foo/baz"
request.path_components = [] request.path_components = []
assert request.path == "/" assert request.path == "/"
request.query = ODict([])
assert request.host == "foo" request.path_components = ["foo", "baz"]
assert request.headers["host"] == "bar" request.query["hello"] = "hello"
assert request.path_components == ("foo", "baz")
request.path_components = ["abc"]
assert request.path == "/abc?hello=hello"
def test_anticache(self): def test_anticache(self):
request = treq() request = treq()
@ -246,26 +248,21 @@ class TestRequestUtils(object):
assert "gzip" in request.headers["Accept-Encoding"] assert "gzip" in request.headers["Accept-Encoding"]
def test_get_urlencoded_form(self): def test_get_urlencoded_form(self):
request = treq(content="foobar") request = treq(content="foobar=baz")
assert request.urlencoded_form is None assert not request.urlencoded_form
request.headers["Content-Type"] = "application/x-www-form-urlencoded" request.headers["Content-Type"] = "application/x-www-form-urlencoded"
assert request.urlencoded_form == ODict(utils.urldecode(request.content)) assert list(request.urlencoded_form.items()) == [("foobar", "baz")]
def test_set_urlencoded_form(self): def test_set_urlencoded_form(self):
request = treq() request = treq()
request.urlencoded_form = ODict([('foo', 'bar'), ('rab', 'oof')]) request.urlencoded_form = [('foo', 'bar'), ('rab', 'oof')]
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
assert request.content assert request.content
def test_get_multipart_form(self): def test_get_multipart_form(self):
request = treq(content="foobar") request = treq(content="foobar")
assert request.multipart_form is None assert not request.multipart_form
request.headers["Content-Type"] = "multipart/form-data" request.headers["Content-Type"] = "multipart/form-data"
assert request.multipart_form == ODict( assert list(request.multipart_form.items()) == []
utils.multipartdecode(
request.headers,
request.content
)
)

View File

@ -6,6 +6,7 @@ import six
import time import time
from netlib.http import Headers from netlib.http import Headers
from netlib.http.cookies import CookieAttrs
from netlib.odict import ODict, ODictCaseless from netlib.odict import ODict, ODictCaseless
from netlib.tutils import raises, tresp from netlib.tutils import raises, tresp
from .test_message import _test_passthrough_attr, _test_decoded_attr from .test_message import _test_passthrough_attr, _test_decoded_attr
@ -13,7 +14,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr
class TestResponseData(object): class TestResponseData(object):
def test_init(self): def test_init(self):
with raises(ValueError if six.PY2 else TypeError): with raises(ValueError):
tresp(headers="foobar") tresp(headers="foobar")
assert isinstance(tresp(headers=None).headers, Headers) assert isinstance(tresp(headers=None).headers, Headers)
@ -56,7 +57,7 @@ class TestResponseUtils(object):
result = resp.cookies result = resp.cookies
assert len(result) == 1 assert len(result) == 1
assert "cookiename" in result assert "cookiename" in result
assert result["cookiename"][0] == ["cookievalue", ODict()] assert result["cookiename"] == ("cookievalue", CookieAttrs())
def test_get_cookies_with_parameters(self): def test_get_cookies_with_parameters(self):
resp = tresp() resp = tresp()
@ -64,13 +65,13 @@ class TestResponseUtils(object):
result = resp.cookies result = resp.cookies
assert len(result) == 1 assert len(result) == 1
assert "cookiename" in result assert "cookiename" in result
assert result["cookiename"][0][0] == "cookievalue" assert result["cookiename"][0] == "cookievalue"
attrs = result["cookiename"][0][1] attrs = result["cookiename"][1]
assert len(attrs) == 4 assert len(attrs) == 4
assert attrs["domain"] == ["example.com"] assert attrs["domain"] == "example.com"
assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] assert attrs["expires"] == "Wed Oct 21 16:29:41 2015"
assert attrs["path"] == ["/"] assert attrs["path"] == "/"
assert attrs["httponly"] == [None] assert attrs["httponly"] is None
def test_get_cookies_no_value(self): def test_get_cookies_no_value(self):
resp = tresp() resp = tresp()
@ -78,8 +79,8 @@ class TestResponseUtils(object):
result = resp.cookies result = resp.cookies
assert len(result) == 1 assert len(result) == 1
assert "cookiename" in result assert "cookiename" in result
assert result["cookiename"][0][0] == "" assert result["cookiename"][0] == ""
assert len(result["cookiename"][0][1]) == 2 assert len(result["cookiename"][1]) == 2
def test_get_cookies_twocookies(self): def test_get_cookies_twocookies(self):
resp = tresp() resp = tresp()
@ -90,19 +91,16 @@ class TestResponseUtils(object):
result = resp.cookies result = resp.cookies
assert len(result) == 2 assert len(result) == 2
assert "cookiename" in result assert "cookiename" in result
assert result["cookiename"][0] == ["cookievalue", ODict()] assert result["cookiename"] == ("cookievalue", CookieAttrs())
assert "othercookie" in result assert "othercookie" in result
assert result["othercookie"][0] == ["othervalue", ODict()] assert result["othercookie"] == ("othervalue", CookieAttrs())
def test_set_cookies(self): def test_set_cookies(self):
resp = tresp() resp = tresp()
v = resp.cookies resp.cookies["foo"] = ("bar", {})
v.add("foo", ["bar", ODictCaseless()])
resp.cookies = v
v = resp.cookies assert len(resp.cookies) == 1
assert len(v) == 1 assert resp.cookies["foo"] == ("bar", CookieAttrs())
assert v["foo"] == [["bar", ODictCaseless()]]
def test_refresh(self): def test_refresh(self):
r = tresp() r = tresp()

View File

@ -0,0 +1,239 @@
from netlib import tutils
from netlib.multidict import MultiDict, ImmutableMultiDict, MultiDictView
class _TMulti(object):
@staticmethod
def _reduce_values(values):
return values[0]
@staticmethod
def _kconv(key):
return key.lower()
class TMultiDict(_TMulti, MultiDict):
pass
class TImmutableMultiDict(_TMulti, ImmutableMultiDict):
pass
class TestMultiDict(object):
@staticmethod
def _multi():
return TMultiDict((
("foo", "bar"),
("bar", "baz"),
("Bar", "bam")
))
def test_init(self):
md = TMultiDict()
assert len(md) == 0
md = TMultiDict([("foo", "bar")])
assert len(md) == 1
assert md.fields == (("foo", "bar"),)
def test_repr(self):
assert repr(self._multi()) == (
"TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]"
)
def test_getitem(self):
md = TMultiDict([("foo", "bar")])
assert "foo" in md
assert "Foo" in md
assert md["foo"] == "bar"
with tutils.raises(KeyError):
_ = md["bar"]
md_multi = TMultiDict(
[("foo", "a"), ("foo", "b")]
)
assert md_multi["foo"] == "a"
def test_setitem(self):
md = TMultiDict()
md["foo"] = "bar"
assert md.fields == (("foo", "bar"),)
md["foo"] = "baz"
assert md.fields == (("foo", "baz"),)
md["bar"] = "bam"
assert md.fields == (("foo", "baz"), ("bar", "bam"))
def test_delitem(self):
md = self._multi()
del md["foo"]
assert "foo" not in md
assert "bar" in md
with tutils.raises(KeyError):
del md["foo"]
del md["bar"]
assert md.fields == ()
def test_iter(self):
md = self._multi()
assert list(md.__iter__()) == ["foo", "bar"]
def test_len(self):
md = TMultiDict()
assert len(md) == 0
md = self._multi()
assert len(md) == 2
def test_eq(self):
assert TMultiDict() == TMultiDict()
assert not (TMultiDict() == 42)
md1 = self._multi()
md2 = self._multi()
assert md1 == md2
md1.fields = md1.fields[1:] + md1.fields[:1]
assert not (md1 == md2)
def test_ne(self):
assert not TMultiDict() != TMultiDict()
assert TMultiDict() != self._multi()
assert TMultiDict() != 42
def test_get_all(self):
md = self._multi()
assert md.get_all("foo") == ["bar"]
assert md.get_all("bar") == ["baz", "bam"]
assert md.get_all("baz") == []
def test_set_all(self):
md = TMultiDict()
md.set_all("foo", ["bar", "baz"])
assert md.fields == (("foo", "bar"), ("foo", "baz"))
md = TMultiDict((
("a", "b"),
("x", "x"),
("c", "d"),
("X", "x"),
("e", "f"),
))
md.set_all("x", ["1", "2", "3"])
assert md.fields == (
("a", "b"),
("x", "1"),
("c", "d"),
("x", "2"),
("e", "f"),
("x", "3"),
)
md.set_all("x", ["4"])
assert md.fields == (
("a", "b"),
("x", "4"),
("c", "d"),
("e", "f"),
)
def test_add(self):
md = self._multi()
md.add("foo", "foo")
assert md.fields == (
("foo", "bar"),
("bar", "baz"),
("Bar", "bam"),
("foo", "foo")
)
def test_insert(self):
md = TMultiDict([("b", "b")])
md.insert(0, "a", "a")
md.insert(2, "c", "c")
assert md.fields == (("a", "a"), ("b", "b"), ("c", "c"))
def test_keys(self):
md = self._multi()
assert list(md.keys()) == ["foo", "bar"]
assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"]
def test_values(self):
md = self._multi()
assert list(md.values()) == ["bar", "baz"]
assert list(md.values(multi=True)) == ["bar", "baz", "bam"]
def test_items(self):
md = self._multi()
assert list(md.items()) == [("foo", "bar"), ("bar", "baz")]
assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")]
def test_to_dict(self):
md = self._multi()
assert md.to_dict() == {
"foo": "bar",
"bar": ["baz", "bam"]
}
def test_state(self):
md = self._multi()
assert len(md.get_state()) == 3
assert md == TMultiDict.from_state(md.get_state())
md2 = TMultiDict()
assert md != md2
md2.set_state(md.get_state())
assert md == md2
class TestImmutableMultiDict(object):
def test_modify(self):
md = TImmutableMultiDict()
with tutils.raises(TypeError):
md["foo"] = "bar"
with tutils.raises(TypeError):
del md["foo"]
with tutils.raises(TypeError):
md.add("foo", "bar")
def test_with_delitem(self):
md = TImmutableMultiDict([("foo", "bar")])
assert md.with_delitem("foo").fields == ()
assert md.fields == (("foo", "bar"),)
def test_with_set_all(self):
md = TImmutableMultiDict()
assert md.with_set_all("foo", ["bar"]).fields == (("foo", "bar"),)
assert md.fields == ()
def test_with_insert(self):
md = TImmutableMultiDict()
assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),)
class TParent(object):
def __init__(self):
self.vals = tuple()
def setter(self, vals):
self.vals = vals
def getter(self):
return self.vals
class TestMultiDictView(object):
def test_modify(self):
p = TParent()
tv = MultiDictView(p.getter, p.setter)
assert len(tv) == 0
tv["a"] = "b"
assert p.vals == (("a", "b"),)
tv["c"] = "b"
assert p.vals == (("a", "b"), ("c", "b"))
assert tv["a"] == "b"