From 44ac64aa7235362acbb96e0f12aa27534580e575 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 18 May 2016 18:46:42 -0700 Subject: [PATCH 1/8] add MultiDict This commit introduces MultiDict, a multi-dictionary similar to ODict, but with improved semantics (as in the Headers class). MultiDict fixes a few issues that were present in the Request/Response API. In particular, `request.cookies["foo"] = "bar"` has previously been a no-op, as the cookies property returned a mutable _copy_ of the cookies. --- examples/modify_form.py | 11 +- examples/modify_querystring.py | 5 +- mitmproxy/flow.py | 8 +- mitmproxy/flow_export.py | 4 +- netlib/encoding.py | 1 - netlib/http/cookies.py | 17 ++- netlib/http/headers.py | 130 +++++----------- netlib/http/http1/read.py | 4 +- netlib/http/http2/connections.py | 12 +- netlib/http/message.py | 35 +++++ netlib/http/request.py | 71 +++++---- netlib/http/response.py | 2 + netlib/multidict.py | 163 +++++++++++++++++++++ netlib/utils.py | 11 -- test/mitmproxy/test_examples.py | 12 +- test/mitmproxy/test_flow.py | 54 ------- test/mitmproxy/test_flow_export.py | 2 +- test/netlib/http/http1/test_read.py | 8 +- test/netlib/http/http2/test_connections.py | 6 +- test/netlib/http/test_cookies.py | 4 +- test/netlib/http/test_headers.py | 10 +- test/netlib/http/test_request.py | 61 ++++---- test/netlib/http/test_response.py | 2 +- 23 files changed, 369 insertions(+), 264 deletions(-) create mode 100644 netlib/multidict.py diff --git a/examples/modify_form.py b/examples/modify_form.py index 86188781b..c4edb2cd7 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,5 +1,8 @@ def request(context, flow): - form = flow.request.urlencoded_form - if form is not None: - form["mitmproxy"] = ["rocks"] - flow.request.urlencoded_form = form + if flow.request.urlencoded_form is not None: + flow.request.urlencoded_form["mitmproxy"] = "rocks" + else: + # This sets the proper content type and overrides the body. + flow.request.urlencoded_form = [ + ("foo", "bar") + ] diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py index d682df697..b89e5c8dc 100644 --- a/examples/modify_querystring.py +++ b/examples/modify_querystring.py @@ -1,5 +1,2 @@ def request(context, flow): - q = flow.request.query - if q: - q["mitmproxy"] = ["rocks"] - flow.request.query = q + flow.request.query["mitmproxy"] = "rocks" diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 7fd97af30..4663144d0 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -156,9 +156,9 @@ class SetHeaders: for _, header, value, cpatt in self.lst: if cpatt(f): if f.response: - f.response.headers.fields.append((header, value)) + f.response.headers.add(header, value) else: - f.request.headers.fields.append((header, value)) + f.request.headers.add(header, value) class StreamLargeBodies(object): @@ -263,7 +263,7 @@ class ServerPlaybackState: form_contents = r.urlencoded_form or r.multipart_form if self.ignore_payload_params and form_contents: key.extend( - p for p in form_contents + p for p in form_contents.items(multi=True) if p[0] not in self.ignore_payload_params ) else: @@ -354,7 +354,7 @@ class StickyCookieState: ] if all(match): c = self.jar[i] - l.extend([cookies.format_cookie_header(c[name]) for name in c.keys()]) + l.extend([cookies.format_cookie_header(c[name].lst) for name in c.keys()]) if l: f.request.stickycookie = True f.request.headers["cookie"] = "; ".join(l) diff --git a/mitmproxy/flow_export.py b/mitmproxy/flow_export.py index d8e65704c..ae282fcea 100644 --- a/mitmproxy/flow_export.py +++ b/mitmproxy/flow_export.py @@ -51,7 +51,7 @@ def python_code(flow): params = "" if flow.request.query: - lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query] + lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()] params = "\nparams = {\n%s}\n" % "".join(lines) args += "\n params=params," @@ -140,7 +140,7 @@ def locust_code(flow): params = "" if flow.request.query: - lines = [" '%s': '%s',\n" % (k, v) for k, v in flow.request.query] + lines = [" %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()] params = "\n params = {\n%s }\n" % "".join(lines) args += "\n params=params," diff --git a/netlib/encoding.py b/netlib/encoding.py index 14479e00c..98502451c 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,7 +5,6 @@ from __future__ import absolute_import from io import BytesIO import gzip import zlib -from .utils import always_byte_args ENCODINGS = {"identity", "gzip", "deflate"} diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 4451f1daf..fd5311469 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,6 +1,4 @@ -from six.moves import http_cookies as Cookie import re -import string from email.utils import parsedate_tz, formatdate, mktime_tz from .. import odict @@ -179,20 +177,27 @@ def format_set_cookie_header(name, value, attrs): return _format_set_cookie_pairs(pairs) +def parse_cookie_headers(cookie_headers): + cookie_list = [] + for header in cookie_headers: + cookie_list.extend(parse_cookie_header(header)) + return cookie_list + + def parse_cookie_header(line): """ Parse a Cookie header value. - Returns a (possibly empty) ODict object. + Returns a list of (lhs, rhs) tuples. """ pairs, off_ = _read_pairs(line) - return odict.ODict(pairs) + return pairs -def format_cookie_header(od): +def format_cookie_header(lst): """ Formats a Cookie header value. """ - return _format_pairs(od.lst) + return _format_pairs(lst) def refresh_set_cookie_header(c, delta): diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 72739f900..7e39c371f 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -1,9 +1,3 @@ -""" - -Unicode Handling ----------------- -See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ -""" from __future__ import absolute_import, print_function, division import re @@ -13,23 +7,22 @@ try: except ImportError: # pragma: no cover from collections import MutableMapping # Workaround for Python < 3.3 - import six +from ..multidict import MultiDict +from ..utils import always_bytes -from netlib.utils import always_byte_args, always_bytes, Serializable +# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ if six.PY2: # pragma: no cover _native = lambda x: x _always_bytes = lambda x: x - _always_byte_args = lambda x: x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. _native = lambda x: x.decode("utf-8", "surrogateescape") _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") - _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping, Serializable): +class Headers(MultiDict): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -49,7 +42,7 @@ class Headers(MutableMapping, Serializable): >>> h["host"] "example.com" - # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples + # Headers can also be created from a list of raw (header_name, header_value) byte tuples >>> h = Headers([ [b"Host",b"example.com"], [b"Accept",b"text/html"], @@ -77,7 +70,6 @@ class Headers(MutableMapping, Serializable): For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ - @_always_byte_args def __init__(self, fields=None, **headers): """ Args: @@ -89,19 +81,25 @@ class Headers(MutableMapping, Serializable): If ``**headers`` contains multiple keys that have equal ``.lower()`` s, the behavior is undefined. """ - self.fields = fields or [] - - for name, value in self.fields: - if not isinstance(name, bytes) or not isinstance(value, bytes): - raise ValueError("Headers passed as fields must be bytes.") + super(Headers, self).__init__(fields) # content_type -> content-type headers = { - _always_bytes(name).replace(b"_", b"-"): value + _always_bytes(name).replace(b"_", b"-"): _always_bytes(value) for name, value in six.iteritems(headers) } self.update(headers) + @staticmethod + def _reduce_values(values): + # Headers can be folded + return ", ".join(values) + + @staticmethod + def _kconv(key): + # Headers are case-insensitive + return key.lower() + def __bytes__(self): if self.fields: return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" @@ -111,98 +109,40 @@ class Headers(MutableMapping, Serializable): if six.PY2: # pragma: no cover __str__ = __bytes__ - @_always_byte_args - def __getitem__(self, name): - values = self.get_all(name) - if not values: - raise KeyError(name) - return ", ".join(values) - - @_always_byte_args - def __setitem__(self, name, value): - idx = self._index(name) - - # To please the human eye, we insert at the same position the first existing header occured. - if idx is not None: - del self[name] - self.fields.insert(idx, [name, value]) - else: - self.fields.append([name, value]) - - @_always_byte_args - def __delitem__(self, name): - if name not in self: - raise KeyError(name) - name = name.lower() - self.fields = [ - field for field in self.fields - if name != field[0].lower() - ] + def __delitem__(self, key): + key = _always_bytes(key) + super(Headers, self).__delitem__(key) def __iter__(self): - seen = set() - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - yield _native(name) + for x in super(Headers, self).__iter__(): + yield _native(x) - def __len__(self): - return len(set(name.lower() for name, _ in self.fields)) - - # __hash__ = object.__hash__ - - def _index(self, name): - name = name.lower() - for i, field in enumerate(self.fields): - if field[0].lower() == name: - return i - return None - - def __eq__(self, other): - if isinstance(other, Headers): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - @_always_byte_args def get_all(self, name): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. - See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 """ - name_lower = name.lower() - values = [_native(value) for n, value in self.fields if n.lower() == name_lower] - return values + name = _always_bytes(name) + return [ + _native(x) for x in + super(Headers, self).get_all(name) + ] - @_always_byte_args def set_all(self, name, values): """ Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ - values = map(_always_bytes, values) # _always_byte_args does not fix lists - if name in self: - del self[name] - self.fields.extend( - [name, value] for value in values - ) + name = _always_bytes(name) + values = [_always_bytes(x) for x in values] + return super(Headers, self).set_all(name, values) - def get_state(self): - return tuple(tuple(field) for field in self.fields) + def insert(self, index, key, value): + key = _always_bytes(key) + value = _always_bytes(value) + super(Headers, self).insert(index, key, value) - def set_state(self, state): - self.fields = [list(field) for field in state] - - @classmethod - def from_state(cls, state): - return cls([list(field) for field in state]) - - @_always_byte_args def replace(self, pattern, repl, flags=0): """ Replaces a regular expression pattern with repl in each "name: value" @@ -211,6 +151,8 @@ class Headers(MutableMapping, Serializable): Returns: The number of replacements made. """ + pattern = _always_bytes(pattern) + repl = _always_bytes(repl) pattern = re.compile(pattern, flags) replacements = 0 diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 6e3a1b93a..d30976bdd 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -316,14 +316,14 @@ def _read_headers(rfile): if not ret: raise HttpSyntaxException("Invalid headers") # continued header - ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() + ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip()) else: try: name, value = line.split(b":", 1) value = value.strip() if not name: raise ValueError() - ret.append([name, value]) + ret.append((name, value)) except ValueError: raise HttpSyntaxException("Invalid headers") return Headers(ret) diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index f900b67cc..6643b6b91 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -201,13 +201,13 @@ class HTTP2Protocol(object): headers = request.headers.copy() if ':authority' not in headers: - headers.fields.insert(0, (b':authority', authority.encode('ascii'))) + headers.insert(0, b':authority', authority.encode('ascii')) if ':scheme' not in headers: - headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) + headers.insert(0, b':scheme', request.scheme.encode('ascii')) if ':path' not in headers: - headers.fields.insert(0, (b':path', request.path.encode('ascii'))) + headers.insert(0, b':path', request.path.encode('ascii')) if ':method' not in headers: - headers.fields.insert(0, (b':method', request.method.encode('ascii'))) + headers.insert(0, b':method', request.method.encode('ascii')) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -224,7 +224,7 @@ class HTTP2Protocol(object): headers = response.headers.copy() if ':status' not in headers: - headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) + headers.insert(0, b':status', str(response.status_code).encode('ascii')) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -420,7 +420,7 @@ class HTTP2Protocol(object): self._handle_unexpected_frame(frm) headers = Headers( - [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] + (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks) ) return stream_id, headers, body diff --git a/netlib/http/message.py b/netlib/http/message.py index da9681a0b..262ef3e1d 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,6 +4,7 @@ import warnings import six +from ..multidict import MultiDict from .headers import Headers from .. import encoding, utils @@ -235,3 +236,37 @@ class decoded(object): def __exit__(self, type, value, tb): if self.ce: self.message.encode(self.ce) + + +class MessageMultiDict(MultiDict): + """ + A MultiDict that provides a proxy view to the underlying message. + """ + + def __init__(self, attr, message): + if False: + # We do not want to call the parent constructor here as that + # would cause an unnecessary parse/unparse pass. + # This is here to silence linters. Message + super(MessageMultiDict, self).__init__(None) + self._attr = attr + self._message = message # type: Message + + @staticmethod + def _kconv(key): + # All request-attributes are case-sensitive. + return key + + @staticmethod + def _reduce_values(values): + # We just return the first element if + # multiple elements exist with the same key. + return values[0] + + @property + def fields(self): + return getattr(self._message, "_" + self._attr) + + @fields.setter + def fields(self, value): + setattr(self._message, self._attr, value) diff --git a/netlib/http/request.py b/netlib/http/request.py index a42150ff9..26ec12cf6 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -11,7 +11,7 @@ from netlib.http import cookies from netlib.odict import ODict from .. import encoding from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData +from .message import Message, _native, _always_bytes, MessageData, MessageMultiDict # This regex extracts & splits the host header into host and port. # Handles the edge case of IPv6 addresses containing colons. @@ -224,45 +224,54 @@ class Request(Message): @property def query(self): + # type: () -> MessageMultiDict """ - The request query string as an :py:class:`ODict` object. - None, if there is no query. + The request query string as an :py:class:`MessageMultiDict` object. """ + return MessageMultiDict("query", self) + + @property + def _query(self): _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return None + return tuple(utils.urldecode(query)) @query.setter - def query(self, odict): - query = utils.urlencode(odict.lst) + def query(self, value): + query = utils.urlencode(value) scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) _, _, _, self.path = utils.parse_url( urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) @property def cookies(self): + # type: () -> MessageMultiDict """ The request cookies. - An empty :py:class:`ODict` object if the cookie monster ate them all. + + An empty :py:class:`MessageMultiDict` object if the cookie monster ate them all. """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret + return MessageMultiDict("cookies", self) + + @property + def _cookies(self): + h = self.headers.get_all("Cookie") + return tuple(cookies.parse_cookie_headers(h)) @cookies.setter - def cookies(self, odict): - self.headers["cookie"] = cookies.format_cookie_header(odict) + def cookies(self, value): + self.headers["cookie"] = cookies.format_cookie_header(value) @property def path_components(self): """ - The URL's path components as a list of strings. + The URL's path components as a tuple of strings. Components are unquoted. """ _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split("/") if i] + # This needs to be a tuple so that it's immutable. + # Otherwise, this would fail silently: + # request.path_components.append("foo") + return tuple(urllib.parse.unquote(i) for i in path.split("/") if i) @path_components.setter def path_components(self, components): @@ -309,34 +318,42 @@ class Request(Message): @property def urlencoded_form(self): """ - The URL-encoded form data as an :py:class:`ODict` object. - None if there is no data or the content-type indicates non-form data. + The URL-encoded form data as an :py:class:`MessageMultiDict` object. + None if the content-type indicates non-form data. """ is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.urldecode(self.content)) + if is_valid_content_type: + return MessageMultiDict("urlencoded_form", self) return None + @property + def _urlencoded_form(self): + return tuple(utils.urldecode(self.content)) + @urlencoded_form.setter - def urlencoded_form(self, odict): + def urlencoded_form(self, value): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = utils.urlencode(odict.lst) + self.content = utils.urlencode(value) @property def multipart_form(self): """ - The multipart form data as an :py:class:`ODict` object. - None if there is no data or the content-type indicates non-form data. + The multipart form data as an :py:class:`MultipartFormDict` object. + None if the content-type indicates non-form data. """ is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() - if self.content and is_valid_content_type: - return ODict(utils.multipartdecode(self.headers,self.content)) + if is_valid_content_type: + return MessageMultiDict("multipart_form", self) return None + @property + def _multipart_form(self): + return utils.multipartdecode(self.headers, self.content) + @multipart_form.setter def multipart_form(self, value): raise NotImplementedError() diff --git a/netlib/http/response.py b/netlib/http/response.py index 2f06149e3..20074dca2 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -70,6 +70,7 @@ class Response(Message): def reason(self, reason): self.data.reason = _always_bytes(reason) + # FIXME @property def cookies(self): """ @@ -88,6 +89,7 @@ class Response(Message): ret.append([name, [value, attrs]]) return ODict(ret) + # FIXME @cookies.setter def cookies(self, odict): values = [] diff --git a/netlib/multidict.py b/netlib/multidict.py new file mode 100644 index 000000000..a7158bc5c --- /dev/null +++ b/netlib/multidict.py @@ -0,0 +1,163 @@ +from __future__ import absolute_import, print_function, division + +from abc import ABCMeta, abstractmethod + +from typing import Tuple + +try: + from collections.abc import MutableMapping +except ImportError: # pragma: no cover + from collections import MutableMapping # Workaround for Python < 3.3 + +import six + +from .utils import Serializable + + +@six.add_metaclass(ABCMeta) +class MultiDict(MutableMapping, Serializable): + def __init__(self, fields=None): + + # it is important for us that .fields is immutable, so that we can easily + # detect changes to it. + self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] + + for key, value in self.fields: + if not isinstance(key, bytes) or not isinstance(value, bytes): + raise TypeError("MultiDict fields must be bytes.") + + def __repr__(self): + fields = tuple( + repr(field) + for field in self.fields + ) + return "{cls}[{fields}]".format( + cls=type(self).__name__, + fields=", ".join(fields) + ) + + @staticmethod + @abstractmethod + def _reduce_values(values): + pass + + @staticmethod + @abstractmethod + def _kconv(v): + pass + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + return self._reduce_values(values) + + def __setitem__(self, key, value): + self.set_all(key, [value]) + + def __delitem__(self, key): + if key not in self: + raise KeyError(key) + key = self._kconv(key) + self.fields = tuple( + field for field in self.fields + if key != self._kconv(field[0]) + ) + + def __iter__(self): + seen = set() + for key, _ in self.fields: + key_kconv = self._kconv(key) + if key_kconv not in seen: + seen.add(key_kconv) + yield key + + def __len__(self): + return len(set(self._kconv(key) for key, _ in self.fields)) + + def __eq__(self, other): + if isinstance(other, MultiDict): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, key): + """ + Return the list of items for a given key. + If that key is not in the MultiDict, + the return value will be an empty list. + """ + key = self._kconv(key) + return [ + value + for k, value in self.fields + if self._kconv(k) == key + ] + + def set_all(self, key, values): + """ + Remove the old values for a key and add new ones. + """ + key_kconv = self._kconv(key) + + new_fields = [] + for field in self.fields: + if self._kconv(field[0]) == key_kconv: + if values: + new_fields.append( + (key, values.pop(0)) + ) + else: + new_fields.append(field) + while values: + new_fields.append( + (key, values.pop(0)) + ) + self.fields = tuple(new_fields) + + def add(self, key, value): + self.insert(len(self.fields), key, value) + + def insert(self, index, key, value): + item = (key, value) + self.fields = self.fields[:index] + (item,) + self.fields[index:] + + def keys(self, multi=False): + return ( + k + for k, _ in self.items(multi) + ) + + def values(self, multi=False): + return ( + v + for _, v in self.items(multi) + ) + + def items(self, multi=False): + if multi: + return self.fields + else: + return super(MultiDict, self).items() + + def to_dict(self): + d = {} + for key in self: + values = self.get_all(key) + if len(values) == 1: + d[key] = values[0] + else: + d[key] = values + return d + + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(tuple(x) for x in state) diff --git a/netlib/utils.py b/netlib/utils.py index be2701a07..7499f71fc 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -51,17 +51,6 @@ def always_bytes(unicode_or_bytes, *encode_args): return unicode_or_bytes -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator - - def native(s, *encoding_opts): """ Convert :py:class:`bytes` or :py:class:`unicode` to the native diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index c401a6b93..d0a258e9c 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -94,14 +94,22 @@ def test_modify_form(): flow = tutils.tflow(req=netutils.treq(headers=form_header)) with example("modify_form.py") as ex: ex.run("request", flow) - assert flow.request.urlencoded_form["mitmproxy"] == ["rocks"] + assert flow.request.urlencoded_form["mitmproxy"] == "rocks" + + flow.request.headers["content-type"] = "" + ex.run("request", flow) + assert list(flow.request.urlencoded_form.items()) == [("foo","bar")] def test_modify_querystring(): flow = tutils.tflow(req=netutils.treq(path="/search?q=term")) with example("modify_querystring.py") as ex: ex.run("request", flow) - assert flow.request.query["mitmproxy"] == ["rocks"] + assert flow.request.query["mitmproxy"] == "rocks" + + flow.request.path = "/" + ex.run("request", flow) + assert flow.request.query["mitmproxy"] == "rocks" def test_modify_response_body(): diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index b9c6a2f64..bf417423b 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -1067,60 +1067,6 @@ class TestRequest: assert r.url == "https://address:22/path" assert r.pretty_url == "https://foo.com:22/path" - def test_path_components(self): - r = HTTPRequest.wrap(netlib.tutils.treq()) - r.path = "/" - assert r.get_path_components() == [] - r.path = "/foo/bar" - assert r.get_path_components() == ["foo", "bar"] - q = odict.ODict() - q["test"] = ["123"] - r.set_query(q) - assert r.get_path_components() == ["foo", "bar"] - - r.set_path_components([]) - assert r.get_path_components() == [] - r.set_path_components(["foo"]) - assert r.get_path_components() == ["foo"] - r.set_path_components(["/oo"]) - assert r.get_path_components() == ["/oo"] - assert "%2F" in r.path - - def test_getset_form_urlencoded(self): - d = odict.ODict([("one", "two"), ("three", "four")]) - r = HTTPRequest.wrap(netlib.tutils.treq(content=netlib.utils.urlencode(d.lst))) - r.headers["content-type"] = "application/x-www-form-urlencoded" - assert r.get_form_urlencoded() == d - - d = odict.ODict([("x", "y")]) - r.set_form_urlencoded(d) - assert r.get_form_urlencoded() == d - - r.headers["content-type"] = "foo" - assert not r.get_form_urlencoded() - - def test_getset_query(self): - r = HTTPRequest.wrap(netlib.tutils.treq()) - r.path = "/foo?x=y&a=b" - q = r.get_query() - assert q.lst == [("x", "y"), ("a", "b")] - - r.path = "/" - q = r.get_query() - assert not q - - r.path = "/?adsfa" - q = r.get_query() - assert q.lst == [("adsfa", "")] - - r.path = "/foo?x=y&a=b" - assert r.get_query() - r.set_query(odict.ODict([])) - assert not r.get_query() - qv = odict.ODict([("a", "b"), ("c", "d")]) - r.set_query(qv) - assert r.get_query() == qv - def test_anticache(self): r = HTTPRequest.wrap(netlib.tutils.treq()) r.headers = Headers() diff --git a/test/mitmproxy/test_flow_export.py b/test/mitmproxy/test_flow_export.py index 035f07b72..2b1f897cb 100644 --- a/test/mitmproxy/test_flow_export.py +++ b/test/mitmproxy/test_flow_export.py @@ -21,7 +21,7 @@ def python_equals(testdata, text): assert clean_blanks(text).rstrip() == clean_blanks(d).rstrip() -req_get = lambda: netlib.tutils.treq(method='GET', content='') +req_get = lambda: netlib.tutils.treq(method='GET', content='', path=b"/") req_post = lambda: netlib.tutils.treq(method='POST', headers=None) diff --git a/test/netlib/http/http1/test_read.py b/test/netlib/http/http1/test_read.py index 902340702..d81069049 100644 --- a/test/netlib/http/http1/test_read.py +++ b/test/netlib/http/http1/test_read.py @@ -261,7 +261,7 @@ class TestReadHeaders(object): b"\r\n" ) headers = self._read(data) - assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]] + assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two")) def test_read_multi(self): data = ( @@ -270,7 +270,7 @@ class TestReadHeaders(object): b"\r\n" ) headers = self._read(data) - assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]] + assert headers.fields == ((b"Header", b"one"), (b"Header", b"two")) def test_read_continued(self): data = ( @@ -280,7 +280,7 @@ class TestReadHeaders(object): b"\r\n" ) headers = self._read(data) - assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]] + assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three")) def test_read_continued_err(self): data = b"\tfoo: bar\r\n" @@ -300,7 +300,7 @@ class TestReadHeaders(object): def test_read_empty_value(self): data = b"bar:" headers = self._read(data) - assert headers.fields == [[b"bar", b""]] + assert headers.fields == ((b"bar", b""),) def test_read_chunked(): req = treq(content=None) diff --git a/test/netlib/http/http2/test_connections.py b/test/netlib/http/http2/test_connections.py index 7b003067a..7d240c0e3 100644 --- a/test/netlib/http/http2/test_connections.py +++ b/test/netlib/http/http2/test_connections.py @@ -312,7 +312,7 @@ class TestReadRequest(tservers.ServerTestBase): req = protocol.read_request(NotImplemented) assert req.stream_id - assert req.headers.fields == [[b':method', b'GET'], [b':path', b'/'], [b':scheme', b'https']] + assert req.headers.fields == ((b':method', b'GET'), (b':path', b'/'), (b':scheme', b'https')) assert req.content == b'foobar' @@ -418,7 +418,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.http_version == "HTTP/2.0" assert resp.status_code == 200 assert resp.reason == '' - assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] + assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) assert resp.content == b'foobar' assert resp.timestamp_end @@ -445,7 +445,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.http_version == "HTTP/2.0" assert resp.status_code == 200 assert resp.reason == '' - assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] + assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar')) assert resp.content == b'' diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py index da28850f2..e2cee57ff 100644 --- a/test/netlib/http/test_cookies.py +++ b/test/netlib/http/test_cookies.py @@ -128,10 +128,10 @@ def test_cookie_roundtrips(): ] for s, lst in pairs: ret = cookies.parse_cookie_header(s) - assert ret.lst == lst + assert ret == lst s2 = cookies.format_cookie_header(ret) ret = cookies.parse_cookie_header(s2) - assert ret.lst == lst + assert ret == lst def test_parse_set_cookie_pairs(): diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 8c1db9dc5..48d3b3233 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -5,10 +5,10 @@ from netlib.tutils import raises class TestHeaders(object): def _2host(self): return Headers( - [ - [b"Host", b"example.com"], - [b"host", b"example.org"] - ] + ( + (b"Host", b"example.com"), + (b"host", b"example.org") + ) ) def test_init(self): @@ -38,7 +38,7 @@ class TestHeaders(object): assert headers["Host"] == "example.com" assert headers["Accept"] == "text/plain" - with raises(ValueError): + with raises(TypeError): Headers([[b"Host", u"not-bytes"]]) def test_getitem(self): diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index 7ed6bd0f0..26593ee12 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -12,7 +12,7 @@ from .test_message import _test_decoded_attr, _test_passthrough_attr class TestRequestData(object): def test_init(self): - with raises(ValueError if six.PY2 else TypeError): + with raises(ValueError): treq(headers="foobar") assert isinstance(treq(headers=None).headers, Headers) @@ -158,16 +158,17 @@ class TestRequestUtils(object): def test_get_query(self): request = treq() - assert request.query is None + assert not request.query request.url = "http://localhost:80/foo?bar=42" - assert request.query.lst == [("bar", "42")] + assert dict(request.query) == {"bar": "42"} def test_set_query(self): - request = treq(host=b"foo", headers = Headers(host=b"bar")) - request.query = ODict([]) - assert request.host == "foo" - assert request.headers["host"] == "bar" + request = treq() + assert not request.query + request.query["foo"] = "bar" + assert request.query["foo"] == "bar" + assert request.path == "/path?foo=bar" def test_get_cookies_none(self): request = treq() @@ -177,47 +178,50 @@ class TestRequestUtils(object): def test_get_cookies_single(self): request = treq() request.headers = Headers(cookie="cookiename=cookievalue") - result = request.cookies - assert len(result) == 1 - assert result['cookiename'] == ['cookievalue'] + assert len(request.cookies) == 1 + assert request.cookies['cookiename'] == 'cookievalue' def test_get_cookies_double(self): request = treq() request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") result = request.cookies assert len(result) == 2 - assert result['cookiename'] == ['cookievalue'] - assert result['othercookiename'] == ['othercookievalue'] + assert result['cookiename'] == 'cookievalue' + assert result['othercookiename'] == 'othercookievalue' def test_get_cookies_withequalsign(self): request = treq() request.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") result = request.cookies assert len(result) == 2 - assert result['cookiename'] == ['coo=kievalue'] - assert result['othercookiename'] == ['othercookievalue'] + assert result['cookiename'] == 'coo=kievalue' + assert result['othercookiename'] == 'othercookievalue' def test_set_cookies(self): request = treq() request.headers = Headers(cookie="cookiename=cookievalue") result = request.cookies - result["cookiename"] = ["foo"] - request.cookies = result - assert request.cookies["cookiename"] == ["foo"] + result["cookiename"] = "foo" + assert request.cookies["cookiename"] == "foo" def test_get_path_components(self): request = treq(path=b"/foo/bar") - assert request.path_components == ["foo", "bar"] + assert request.path_components == ("foo", "bar") def test_set_path_components(self): - request = treq(host=b"foo", headers = Headers(host=b"bar")) + request = treq() request.path_components = ["foo", "baz"] assert request.path == "/foo/baz" + request.path_components = [] assert request.path == "/" - request.query = ODict([]) - assert request.host == "foo" - assert request.headers["host"] == "bar" + + request.path_components = ["foo", "baz"] + request.query["hello"] = "hello" + assert request.path_components == ("foo", "baz") + + request.path_components = ["abc"] + assert request.path == "/abc?hello=hello" def test_anticache(self): request = treq() @@ -246,15 +250,15 @@ class TestRequestUtils(object): assert "gzip" in request.headers["Accept-Encoding"] def test_get_urlencoded_form(self): - request = treq(content="foobar") + request = treq(content="foobar=baz") assert request.urlencoded_form is None request.headers["Content-Type"] = "application/x-www-form-urlencoded" - assert request.urlencoded_form == ODict(utils.urldecode(request.content)) + assert list(request.urlencoded_form.items()) == [("foobar", "baz")] def test_set_urlencoded_form(self): request = treq() - request.urlencoded_form = ODict([('foo', 'bar'), ('rab', 'oof')]) + request.urlencoded_form = [('foo', 'bar'), ('rab', 'oof')] assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" assert request.content @@ -263,9 +267,4 @@ class TestRequestUtils(object): assert request.multipart_form is None request.headers["Content-Type"] = "multipart/form-data" - assert request.multipart_form == ODict( - utils.multipartdecode( - request.headers, - request.content - ) - ) + assert list(request.multipart_form.items()) == [] diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py index 5440176c3..37273541a 100644 --- a/test/netlib/http/test_response.py +++ b/test/netlib/http/test_response.py @@ -13,7 +13,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr class TestResponseData(object): def test_init(self): - with raises(ValueError if six.PY2 else TypeError): + with raises(ValueError): tresp(headers="foobar") assert isinstance(tresp(headers=None).headers, Headers) From 8e39b7bf38e7becd1116dfcded380327fd0228d0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 18 May 2016 19:28:23 -0700 Subject: [PATCH 2/8] test flow export with duplicate query string --- test/mitmproxy/test_flow_export.py | 6 +++--- test/mitmproxy/test_flow_export/locust_get.py | 6 ++++++ test/mitmproxy/test_flow_export/locust_task_get.py | 6 ++++++ test/mitmproxy/test_flow_export/python_get.py | 6 ++++++ 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/mitmproxy/test_flow_export.py b/test/mitmproxy/test_flow_export.py index 2b1f897cb..c252c5bd3 100644 --- a/test/mitmproxy/test_flow_export.py +++ b/test/mitmproxy/test_flow_export.py @@ -21,7 +21,7 @@ def python_equals(testdata, text): assert clean_blanks(text).rstrip() == clean_blanks(d).rstrip() -req_get = lambda: netlib.tutils.treq(method='GET', content='', path=b"/") +req_get = lambda: netlib.tutils.treq(method='GET', content='', path=b"/path?a=foo&a=bar&b=baz") req_post = lambda: netlib.tutils.treq(method='POST', headers=None) @@ -31,7 +31,7 @@ req_patch = lambda: netlib.tutils.treq(method='PATCH', path=b"/path?query=param" class TestExportCurlCommand(): def test_get(self): flow = tutils.tflow(req=req_get()) - result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path'""" + result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path?a=foo&a=bar&b=baz'""" assert flow_export.curl_command(flow) == result def test_post(self): @@ -70,7 +70,7 @@ class TestRawRequest(): def test_get(self): flow = tutils.tflow(req=req_get()) result = dedent(""" - GET /path HTTP/1.1\r + GET /path?a=foo&a=bar&b=baz HTTP/1.1\r header: qvalue\r content-length: 7\r host: address:22\r diff --git a/test/mitmproxy/test_flow_export/locust_get.py b/test/mitmproxy/test_flow_export/locust_get.py index 72d5932aa..632d5d537 100644 --- a/test/mitmproxy/test_flow_export/locust_get.py +++ b/test/mitmproxy/test_flow_export/locust_get.py @@ -14,10 +14,16 @@ class UserBehavior(TaskSet): 'content-length': '7', } + params = { + 'a': ['foo', 'bar'], + 'b': 'baz', + } + self.response = self.client.request( method='GET', url=url, headers=headers, + params=params, ) ### Additional tasks can go here ### diff --git a/test/mitmproxy/test_flow_export/locust_task_get.py b/test/mitmproxy/test_flow_export/locust_task_get.py index 76f144fa4..03821cd86 100644 --- a/test/mitmproxy/test_flow_export/locust_task_get.py +++ b/test/mitmproxy/test_flow_export/locust_task_get.py @@ -7,8 +7,14 @@ 'content-length': '7', } + params = { + 'a': ['foo', 'bar'], + 'b': 'baz', + } + self.response = self.client.request( method='GET', url=url, headers=headers, + params=params, ) diff --git a/test/mitmproxy/test_flow_export/python_get.py b/test/mitmproxy/test_flow_export/python_get.py index ee3f48ebf..af8f7c81b 100644 --- a/test/mitmproxy/test_flow_export/python_get.py +++ b/test/mitmproxy/test_flow_export/python_get.py @@ -7,10 +7,16 @@ headers = { 'content-length': '7', } +params = { + 'a': ['foo', 'bar'], + 'b': 'baz', +} + response = requests.request( method='GET', url=url, headers=headers, + params=params, ) print(response.text) From 6f8db2d7eb32684a8328e0ae8bdd73eceb861707 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 18 May 2016 22:50:19 -0700 Subject: [PATCH 3/8] improve MultiDict, add ImmutableMultiDict, adjust response.cookies --- examples/modify_form.py | 2 +- mitmproxy/console/flowview.py | 50 ++-- mitmproxy/console/grideditor.py | 12 +- mitmproxy/flow.py | 28 +-- netlib/http/__init__.py | 8 +- netlib/http/cookies.py | 43 +++- netlib/http/headers.py | 4 + netlib/http/message.py | 41 ++- netlib/http/request.py | 69 ++--- netlib/http/response.py | 43 ++-- netlib/multidict.py | 403 ++++++++++++++++++------------ test/mitmproxy/test_examples.py | 2 +- test/netlib/http/test_cookies.py | 14 +- test/netlib/http/test_request.py | 4 +- test/netlib/http/test_response.py | 32 ++- 15 files changed, 432 insertions(+), 323 deletions(-) diff --git a/examples/modify_form.py b/examples/modify_form.py index c4edb2cd7..3fe0cf964 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,5 +1,5 @@ def request(context, flow): - if flow.request.urlencoded_form is not None: + if flow.request.urlencoded_form: flow.request.urlencoded_form["mitmproxy"] = "rocks" else: # This sets the proper content type and overrides the body. diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py index b2ebe49eb..3538c4ad9 100644 --- a/mitmproxy/console/flowview.py +++ b/mitmproxy/console/flowview.py @@ -6,8 +6,7 @@ import sys import math import urwid -from netlib import odict -from netlib.http import Headers +from netlib.http import Headers, status_codes from . import common, grideditor, signals, searchable, tabs from . import flowdetailview from .. import utils, controller, contentviews @@ -316,21 +315,18 @@ class FlowView(tabs.Tabs): return "Invalid URL." signals.flow_change.send(self, flow = self.flow) - def set_resp_code(self, code): - response = self.flow.response + def set_resp_status_code(self, status_code): try: - response.status_code = int(code) + status_code = int(status_code) except ValueError: return None - import BaseHTTPServer - if int(code) in BaseHTTPServer.BaseHTTPRequestHandler.responses: - response.msg = BaseHTTPServer.BaseHTTPRequestHandler.responses[ - int(code)][0] + self.flow.response.status_code = status_code + if status_code in status_codes.RESPONSES: + self.flow.response.reason = status_codes.RESPONSES[status_code] signals.flow_change.send(self, flow = self.flow) - def set_resp_msg(self, msg): - response = self.flow.response - response.msg = msg + def set_resp_reason(self, reason): + self.flow.response.reason = reason signals.flow_change.send(self, flow = self.flow) def set_headers(self, fields, conn): @@ -338,22 +334,22 @@ class FlowView(tabs.Tabs): signals.flow_change.send(self, flow = self.flow) def set_query(self, lst, conn): - conn.set_query(odict.ODict(lst)) + conn.query = lst signals.flow_change.send(self, flow = self.flow) def set_path_components(self, lst, conn): - conn.set_path_components(lst) + conn.path_components = lst signals.flow_change.send(self, flow = self.flow) def set_form(self, lst, conn): - conn.set_form_urlencoded(odict.ODict(lst)) + conn.urlencoded_form = lst signals.flow_change.send(self, flow = self.flow) def edit_form(self, conn): self.master.view_grideditor( grideditor.URLEncodedFormEditor( self.master, - conn.get_form_urlencoded().lst, + conn.urlencoded_form.items(multi=True), self.set_form, conn ) @@ -364,7 +360,7 @@ class FlowView(tabs.Tabs): self.edit_form(conn) def set_cookies(self, lst, conn): - conn.cookies = odict.ODict(lst) + conn.cookies = lst signals.flow_change.send(self, flow = self.flow) def set_setcookies(self, data, conn): @@ -388,7 +384,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.CookieEditor( self.master, - message.cookies.lst, + message.cookies.items(multi=True), self.set_cookies, message ) @@ -397,7 +393,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.SetCookieEditor( self.master, - message.cookies, + message.cookies.items(multi=True), self.set_setcookies, message ) @@ -413,7 +409,7 @@ class FlowView(tabs.Tabs): c = self.master.spawn_editor(message.content or "") message.content = c.rstrip("\n") elif part == "f": - if not message.get_form_urlencoded() and message.content: + if not message.urlencoded_form and message.content: signals.status_prompt_onekey.send( prompt = "Existing body is not a URL-encoded form. Clear and edit?", keys = [ @@ -435,7 +431,7 @@ class FlowView(tabs.Tabs): ) ) elif part == "p": - p = message.get_path_components() + p = message.path_components self.master.view_grideditor( grideditor.PathEditor( self.master, @@ -448,7 +444,7 @@ class FlowView(tabs.Tabs): self.master.view_grideditor( grideditor.QueryEditor( self.master, - message.get_query().lst, + message.query.items(multi=True), self.set_query, message ) ) @@ -458,7 +454,7 @@ class FlowView(tabs.Tabs): text = message.url, callback = self.set_url ) - elif part == "m": + elif part == "m" and message == self.flow.request: signals.status_prompt_onekey.send( prompt = "Method", keys = common.METHOD_OPTIONS, @@ -468,13 +464,13 @@ class FlowView(tabs.Tabs): signals.status_prompt.send( prompt = "Code", text = str(message.status_code), - callback = self.set_resp_code + callback = self.set_resp_status_code ) - elif part == "m": + elif part == "m" and message == self.flow.response: signals.status_prompt.send( prompt = "Message", - text = message.msg, - callback = self.set_resp_msg + text = message.reason, + callback = self.set_resp_reason ) signals.flow_change.send(self, flow = self.flow) diff --git a/mitmproxy/console/grideditor.py b/mitmproxy/console/grideditor.py index 46ff348e9..11ce7d02d 100644 --- a/mitmproxy/console/grideditor.py +++ b/mitmproxy/console/grideditor.py @@ -700,17 +700,17 @@ class SetCookieEditor(GridEditor): def data_in(self, data): flattened = [] - for k, v in data.items(): - flattened.append([k, v[0], v[1].lst]) + for key, (value, attrs) in data: + flattened.append([key, value, attrs.items(multi=True)]) return flattened def data_out(self, data): vals = [] - for i in data: + for key, value, attrs in data: vals.append( [ - i[0], - [i[1], odict.ODictCaseless(i[2])] + key, + (value, attrs) ] ) - return odict.ODict(vals) + return vals diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 4663144d0..647ebf687 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -319,10 +319,10 @@ class StickyCookieState: """ domain = f.request.host path = "/" - if attrs["domain"]: - domain = attrs["domain"][-1] - if attrs["path"]: - path = attrs["path"][-1] + if "domain" in attrs: + domain = attrs["domain"] + if "path" in attrs: + path = attrs["path"] return (domain, f.request.port, path) def domain_match(self, a, b): @@ -333,28 +333,26 @@ class StickyCookieState: return False def handle_response(self, f): - for i in f.response.headers.get_all("set-cookie"): + for name, (value, attrs) in f.response.cookies.items(multi=True): # FIXME: We now know that Cookie.py screws up some cookies with # valid RFC 822/1123 datetime specifications for expiry. Sigh. - name, value, attrs = cookies.parse_set_cookie_header(str(i)) a = self.ckey(attrs, f) if self.domain_match(f.request.host, a[0]): - b = attrs.lst - b.insert(0, [name, value]) - self.jar[a][name] = odict.ODictCaseless(b) + b = attrs.with_insert(0, name, value) + self.jar[a][name] = b def handle_request(self, f): l = [] if f.match(self.flt): - for i in self.jar.keys(): + for domain, port, path in self.jar.keys(): match = [ - self.domain_match(f.request.host, i[0]), - f.request.port == i[1], - f.request.path.startswith(i[2]) + self.domain_match(f.request.host, domain), + f.request.port == port, + f.request.path.startswith(path) ] if all(match): - c = self.jar[i] - l.extend([cookies.format_cookie_header(c[name].lst) for name in c.keys()]) + c = self.jar[(domain, port, path)] + l.extend([cookies.format_cookie_header(c[name].items(multi=True)) for name in c.keys()]) if l: f.request.stickycookie = True f.request.headers["cookie"] = "; ".join(l) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 917080f7a..9fafa28fc 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -2,13 +2,13 @@ from __future__ import absolute_import, print_function, division from .request import Request from .response import Response from .headers import Headers -from .message import decoded -from . import http1, http2 +from .message import MultiDictView, decoded +from . import http1, http2, status_codes __all__ = [ "Request", "Response", "Headers", - "decoded", - "http1", "http2", + "MultiDictView", "decoded", + "http1", "http2", "status_codes", ] diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index fd5311469..c5ac45918 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,6 +1,8 @@ +import collections import re from email.utils import parsedate_tz, formatdate, mktime_tz +from netlib.multidict import ImmutableMultiDict from .. import odict """ @@ -155,25 +157,52 @@ def _parse_set_cookie_pairs(s): return pairs +def parse_set_cookie_headers(headers): + ret = [] + for header in headers: + v = parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append((name, SetCookie(value, attrs))) + return ret + + +class CookieAttrs(ImmutableMultiDict): + @staticmethod + def _kconv(v): + return v.lower() + + @staticmethod + def _reduce_values(values): + # See the StickyCookieTest for a weird cookie that only makes sense + # if we take the last part. + return values[-1] + + +SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"]) + + def parse_set_cookie_header(line): """ Parse a Set-Cookie header value Returns a (name, value, attrs) tuple, or None, where attrs is an - ODictCaseless set of attributes. No attempt is made to parse attribute + CookieAttrs dict of attributes. No attempt is made to parse attribute values - they are treated purely as strings. """ pairs = _parse_set_cookie_pairs(line) if pairs: - return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + return pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:]) def format_set_cookie_header(name, value, attrs): """ Formats a Set-Cookie header value. """ - pairs = [[name, value]] - pairs.extend(attrs.lst) + pairs = [(name, value)] + pairs.extend( + attrs.fields if hasattr(attrs, "fields") else attrs + ) return _format_set_cookie_pairs(pairs) @@ -214,10 +243,10 @@ def refresh_set_cookie_header(c, delta): raise ValueError("Invalid Cookie") if "expires" in attrs: - e = parsedate_tz(attrs["expires"][-1]) + e = parsedate_tz(attrs["expires"]) if e: f = mktime_tz(e) + delta - attrs["expires"] = [formatdate(f)] + attrs = attrs.with_set_all("expires", [formatdate(f)]) else: # This can happen when the expires tag is invalid. # reddit.com sends a an expires tag like this: "Thu, 31 Dec @@ -225,7 +254,7 @@ def refresh_set_cookie_header(c, delta): # strictly correct according to the cookie spec. Browsers # appear to parse this tolerantly - maybe we should too. # For now, we just ignore this. - del attrs["expires"] + attrs = attrs.with_delitem("expires") ret = format_set_cookie_header(name, value, attrs) if not ret: diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 7e39c371f..8959394c8 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -83,6 +83,10 @@ class Headers(MultiDict): """ super(Headers, self).__init__(fields) + for key, value in self.fields: + if not isinstance(key, bytes) or not isinstance(value, bytes): + raise TypeError("Header fields must be bytes.") + # content_type -> content-type headers = { _always_bytes(name).replace(b"_", b"-"): _always_bytes(value) diff --git a/netlib/http/message.py b/netlib/http/message.py index 262ef3e1d..3c731ea6e 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -238,9 +238,44 @@ class decoded(object): self.message.encode(self.ce) -class MessageMultiDict(MultiDict): +class MultiDictView(MultiDict): """ - A MultiDict that provides a proxy view to the underlying message. + Some parts in HTTP (Cookies, URL query strings, ...) require a specific data structure: A MultiDict. + It behaves mostly like an ordered dict but it can have several values for the same key. + + The MultiDictView provides a MultiDict *view* on an :py:class:`Request` or :py:class:`Response`. + That is, it represents a part of the request as a MultiDict, but doesn't contain state/data themselves. + + For example, ``request.cookies`` provides a view on the ``Cookie: ...`` header. + Any change to ``request.cookies`` will also modify the ``Cookie`` header. + Any change to the ``Cookie`` header will also modify ``request.cookies``. + + Example: + + .. code-block:: python + + # Cookies are represented as a MultiDict. + >>> request.cookies + MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + + # MultiDicts mostly behave like a normal dict. + >>> request.cookies["name"] + "value" + + # If there is more than one value, only the first value is returned. + >>> request.cookies["a"] + "false" + + # `.get_all(key)` returns a list of all values. + >>> request.cookies.get_all("a") + ["false", "42"] + + # Changes to the headers are immediately reflected in the cookies. + >>> request.cookies + MultiDictView[("name", "value"), ...] + >>> del request.headers["Cookie"] + >>> request.cookies + MultiDictView[] # empty now """ def __init__(self, attr, message): @@ -248,7 +283,7 @@ class MessageMultiDict(MultiDict): # We do not want to call the parent constructor here as that # would cause an unnecessary parse/unparse pass. # This is here to silence linters. Message - super(MessageMultiDict, self).__init__(None) + super(MultiDictView, self).__init__(None) self._attr = attr self._message = message # type: Message diff --git a/netlib/http/request.py b/netlib/http/request.py index 26ec12cf6..ae28084b5 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -11,7 +11,7 @@ from netlib.http import cookies from netlib.odict import ODict from .. import encoding from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData, MessageMultiDict +from .message import Message, _native, _always_bytes, MessageData, MultiDictView # This regex extracts & splits the host header into host and port. # Handles the edge case of IPv6 addresses containing colons. @@ -224,11 +224,11 @@ class Request(Message): @property def query(self): - # type: () -> MessageMultiDict + # type: () -> MultiDictView """ - The request query string as an :py:class:`MessageMultiDict` object. + The request query string as an :py:class:`MultiDictView` object. """ - return MessageMultiDict("query", self) + return MultiDictView("query", self) @property def _query(self): @@ -244,13 +244,13 @@ class Request(Message): @property def cookies(self): - # type: () -> MessageMultiDict + # type: () -> MultiDictView """ The request cookies. - An empty :py:class:`MessageMultiDict` object if the cookie monster ate them all. + An empty :py:class:`MultiDictView` object if the cookie monster ate them all. """ - return MessageMultiDict("cookies", self) + return MultiDictView("cookies", self) @property def _cookies(self): @@ -318,17 +318,18 @@ class Request(Message): @property def urlencoded_form(self): """ - The URL-encoded form data as an :py:class:`MessageMultiDict` object. - None if the content-type indicates non-form data. + The URL-encoded form data as an :py:class:`MultiDictView` object. + An empty MultiDictView if the content-type indicates non-form data + or the content could not be parsed. """ - is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() - if is_valid_content_type: - return MessageMultiDict("urlencoded_form", self) - return None + return MultiDictView("urlencoded_form", self) @property def _urlencoded_form(self): - return tuple(utils.urldecode(self.content)) + is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() + if is_valid_content_type: + return tuple(utils.urldecode(self.content)) + return () @urlencoded_form.setter def urlencoded_form(self, value): @@ -345,45 +346,15 @@ class Request(Message): The multipart form data as an :py:class:`MultipartFormDict` object. None if the content-type indicates non-form data. """ - is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() - if is_valid_content_type: - return MessageMultiDict("multipart_form", self) - return None + return MultiDictView("multipart_form", self) @property def _multipart_form(self): - return utils.multipartdecode(self.headers, self.content) + is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() + if is_valid_content_type: + return utils.multipartdecode(self.headers, self.content) + return () @multipart_form.setter def multipart_form(self, value): raise NotImplementedError() - - # Legacy - - def get_query(self): # pragma: no cover - warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) - return self.query or ODict([]) - - def set_query(self, odict): # pragma: no cover - warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) - self.query = odict - - def get_path_components(self): # pragma: no cover - warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) - return self.path_components - - def set_path_components(self, lst): # pragma: no cover - warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) - self.path_components = lst - - def get_form_urlencoded(self): # pragma: no cover - warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - return self.urlencoded_form or ODict([]) - - def set_form_urlencoded(self, odict): # pragma: no cover - warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) - self.urlencoded_form = odict - - def get_form_multipart(self): # pragma: no cover - warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) - return self.multipart_form or ODict([]) diff --git a/netlib/http/response.py b/netlib/http/response.py index 20074dca2..6d56fc1f9 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,14 +1,12 @@ from __future__ import absolute_import, print_function, division -import warnings from email.utils import parsedate_tz, formatdate, mktime_tz import time from . import cookies from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData +from .message import Message, _native, _always_bytes, MessageData, MultiDictView from .. import utils -from ..odict import ODict class ResponseData(MessageData): @@ -70,33 +68,32 @@ class Response(Message): def reason(self, reason): self.data.reason = _always_bytes(reason) - # FIXME @property def cookies(self): + # type: () -> MultiDictView """ - Get the contents of all Set-Cookie headers. + The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are + cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly) + are indicated by a Null value. - A possibly empty :py:class:`ODict`, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. + Caveats: + Updating the attr """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return ODict(ret) + return MultiDictView("cookies", self) + + @property + def _cookies(self): + h = self.headers.get_all("set-cookie") + return tuple(cookies.parse_set_cookie_headers(h)) - # FIXME @cookies.setter - def cookies(self, odict): - values = [] - for i in odict.lst: - header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) - values.append(header) - self.headers.set_all("set-cookie", values) + def cookies(self, all_cookies): + cookie_headers = [] + for k, v in all_cookies: + header = cookies.format_set_cookie_header(k, v[0], v[1]) + cookie_headers.append(header) + self.headers.set_all("set-cookie", cookie_headers) def refresh(self, now=None): """ diff --git a/netlib/multidict.py b/netlib/multidict.py index a7158bc5c..32d5bfc2e 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -1,163 +1,240 @@ -from __future__ import absolute_import, print_function, division - -from abc import ABCMeta, abstractmethod - -from typing import Tuple - -try: - from collections.abc import MutableMapping -except ImportError: # pragma: no cover - from collections import MutableMapping # Workaround for Python < 3.3 - -import six - -from .utils import Serializable - - -@six.add_metaclass(ABCMeta) -class MultiDict(MutableMapping, Serializable): - def __init__(self, fields=None): - - # it is important for us that .fields is immutable, so that we can easily - # detect changes to it. - self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] - - for key, value in self.fields: - if not isinstance(key, bytes) or not isinstance(value, bytes): - raise TypeError("MultiDict fields must be bytes.") - - def __repr__(self): - fields = tuple( - repr(field) - for field in self.fields - ) - return "{cls}[{fields}]".format( - cls=type(self).__name__, - fields=", ".join(fields) - ) - - @staticmethod - @abstractmethod - def _reduce_values(values): - pass - - @staticmethod - @abstractmethod - def _kconv(v): - pass - - def __getitem__(self, key): - values = self.get_all(key) - if not values: - raise KeyError(key) - return self._reduce_values(values) - - def __setitem__(self, key, value): - self.set_all(key, [value]) - - def __delitem__(self, key): - if key not in self: - raise KeyError(key) - key = self._kconv(key) - self.fields = tuple( - field for field in self.fields - if key != self._kconv(field[0]) - ) - - def __iter__(self): - seen = set() - for key, _ in self.fields: - key_kconv = self._kconv(key) - if key_kconv not in seen: - seen.add(key_kconv) - yield key - - def __len__(self): - return len(set(self._kconv(key) for key, _ in self.fields)) - - def __eq__(self, other): - if isinstance(other, MultiDict): - return self.fields == other.fields - return False - - def __ne__(self, other): - return not self.__eq__(other) - - def get_all(self, key): - """ - Return the list of items for a given key. - If that key is not in the MultiDict, - the return value will be an empty list. - """ - key = self._kconv(key) - return [ - value - for k, value in self.fields - if self._kconv(k) == key - ] - - def set_all(self, key, values): - """ - Remove the old values for a key and add new ones. - """ - key_kconv = self._kconv(key) - - new_fields = [] - for field in self.fields: - if self._kconv(field[0]) == key_kconv: - if values: - new_fields.append( - (key, values.pop(0)) - ) - else: - new_fields.append(field) - while values: - new_fields.append( - (key, values.pop(0)) - ) - self.fields = tuple(new_fields) - - def add(self, key, value): - self.insert(len(self.fields), key, value) - - def insert(self, index, key, value): - item = (key, value) - self.fields = self.fields[:index] + (item,) + self.fields[index:] - - def keys(self, multi=False): - return ( - k - for k, _ in self.items(multi) - ) - - def values(self, multi=False): - return ( - v - for _, v in self.items(multi) - ) - - def items(self, multi=False): - if multi: - return self.fields - else: - return super(MultiDict, self).items() - - def to_dict(self): - d = {} - for key in self: - values = self.get_all(key) - if len(values) == 1: - d[key] = values[0] - else: - d[key] = values - return d - - def get_state(self): - return self.fields - - def set_state(self, state): - self.fields = tuple(tuple(x) for x in state) - - @classmethod - def from_state(cls, state): - return cls(tuple(x) for x in state) +from __future__ import absolute_import, print_function, division + +from abc import ABCMeta, abstractmethod + +from typing import Tuple, TypeVar + +try: + from collections.abc import MutableMapping +except ImportError: # pragma: no cover + from collections import MutableMapping # Workaround for Python < 3.3 + +import six + +from .utils import Serializable + + +@six.add_metaclass(ABCMeta) +class MultiDict(MutableMapping, Serializable): + def __init__(self, fields=None): + + # it is important for us that .fields is immutable, so that we can easily + # detect changes to it. + self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] + + def __repr__(self): + fields = tuple( + repr(field) + for field in self.fields + ) + return "{cls}[{fields}]".format( + cls=type(self).__name__, + fields=", ".join(fields) + ) + + @staticmethod + @abstractmethod + def _reduce_values(values): + pass + + @staticmethod + @abstractmethod + def _kconv(v): + pass + + def __getitem__(self, key): + values = self.get_all(key) + if not values: + raise KeyError(key) + return self._reduce_values(values) + + def __setitem__(self, key, value): + self.set_all(key, [value]) + + def __delitem__(self, key): + if key not in self: + raise KeyError(key) + key = self._kconv(key) + self.fields = tuple( + field for field in self.fields + if key != self._kconv(field[0]) + ) + + def __iter__(self): + seen = set() + for key, _ in self.fields: + key_kconv = self._kconv(key) + if key_kconv not in seen: + seen.add(key_kconv) + yield key + + def __len__(self): + return len(set(self._kconv(key) for key, _ in self.fields)) + + def __eq__(self, other): + if isinstance(other, MultiDict): + return self.fields == other.fields + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_all(self, key): + """ + Return the list of all values for a given key. + If that key is not in the MultiDict, the return value will be an empty list. + """ + key = self._kconv(key) + return [ + value + for k, value in self.fields + if self._kconv(k) == key + ] + + def set_all(self, key, values): + """ + Remove the old values for a key and add new ones. + """ + key_kconv = self._kconv(key) + + new_fields = [] + for field in self.fields: + if self._kconv(field[0]) == key_kconv: + if values: + new_fields.append( + (key, values.pop(0)) + ) + else: + new_fields.append(field) + while values: + new_fields.append( + (key, values.pop(0)) + ) + self.fields = tuple(new_fields) + + def add(self, key, value): + """ + Add an additional value for the given key at the bottom. + """ + self.insert(len(self.fields), key, value) + + def insert(self, index, key, value): + """ + Insert an additional value for the given key at the specified position. + """ + item = (key, value) + self.fields = self.fields[:index] + (item,) + self.fields[index:] + + def keys(self, multi=False): + """ + Get all keys. + + Args: + multi(bool): + If True, one key per value will be returned. + If False, duplicate keys will only be returned once. + """ + return ( + k + for k, _ in self.items(multi) + ) + + def values(self, multi=False): + """ + Get all values. + + Args: + multi(bool): + If True, all values will be returned. + If False, only the first value per key will be returned. + """ + return ( + v + for _, v in self.items(multi) + ) + + def items(self, multi=False): + """ + Get all (key, value) tuples. + + Args: + multi(bool): + If True, all (key, value) pairs will be returned + If False, only the first (key, value) pair per unique key will be returned. + """ + if multi: + return self.fields + else: + return super(MultiDict, self).items() + + def to_dict(self): + """ + Get the MultiDict as a plain Python dict. + Keys with multiple values are returned as lists. + + Example: + + .. code-block:: python + + # Simple dict with duplicate values. + >>> d + MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + >>> d.to_dict() + { + "name": "value", + "a": ["false", "42"] + } + """ + d = {} + for key in self: + values = self.get_all(key) + if len(values) == 1: + d[key] = values[0] + else: + d[key] = values + return d + + def get_state(self): + return self.fields + + def set_state(self, state): + self.fields = tuple(tuple(x) for x in state) + + @classmethod + def from_state(cls, state): + return cls(tuple(x) for x in state) + + +@six.add_metaclass(ABCMeta) +class ImmutableMultiDict(MultiDict): + def _immutable(self, *_): + raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) + + __delitem__ = set_all = insert = _immutable + + def with_delitem(self, key): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).__delitem__(key) + return ret + + def with_set_all(self, key, values): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).set_all(key, values) + return ret + + def with_insert(self, index, key, value): + """ + Returns: + An updated ImmutableMultiDict. The original object will not be modified. + """ + ret = self.copy() + super(ImmutableMultiDict, ret).insert(index, key, value) + return ret diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index d0a258e9c..ac79b0936 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -98,7 +98,7 @@ def test_modify_form(): flow.request.headers["content-type"] = "" ex.run("request", flow) - assert list(flow.request.urlencoded_form.items()) == [("foo","bar")] + assert list(flow.request.urlencoded_form.items()) == [("foo", "bar")] def test_modify_querystring(): diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py index e2cee57ff..6f84c4ce8 100644 --- a/test/netlib/http/test_cookies.py +++ b/test/netlib/http/test_cookies.py @@ -197,24 +197,28 @@ def test_parse_set_cookie_header(): ], [ "one=uno", - ("one", "uno", []) + ("one", "uno", ()) ], [ "one=uno; foo=bar", - ("one", "uno", [["foo", "bar"]]) - ] + ("one", "uno", (("foo", "bar"),)) + ], + [ + "one=uno; foo=bar; foo=baz", + ("one", "uno", (("foo", "bar"), ("foo", "baz"))) + ], ] for s, expected in vals: ret = cookies.parse_set_cookie_header(s) if expected: assert ret[0] == expected[0] assert ret[1] == expected[1] - assert ret[2].lst == expected[2] + assert ret[2].items(multi=True) == expected[2] s2 = cookies.format_set_cookie_header(*ret) ret2 = cookies.parse_set_cookie_header(s2) assert ret2[0] == expected[0] assert ret2[1] == expected[1] - assert ret2[2].lst == expected[2] + assert ret2[2].items(multi=True) == expected[2] else: assert ret is None diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index 26593ee12..eefdc0917 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -251,7 +251,7 @@ class TestRequestUtils(object): def test_get_urlencoded_form(self): request = treq(content="foobar=baz") - assert request.urlencoded_form is None + assert not request.urlencoded_form request.headers["Content-Type"] = "application/x-www-form-urlencoded" assert list(request.urlencoded_form.items()) == [("foobar", "baz")] @@ -264,7 +264,7 @@ class TestRequestUtils(object): def test_get_multipart_form(self): request = treq(content="foobar") - assert request.multipart_form is None + assert not request.multipart_form request.headers["Content-Type"] = "multipart/form-data" assert list(request.multipart_form.items()) == [] diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py index 37273541a..cfd093d4c 100644 --- a/test/netlib/http/test_response.py +++ b/test/netlib/http/test_response.py @@ -6,6 +6,7 @@ import six import time from netlib.http import Headers +from netlib.http.cookies import CookieAttrs from netlib.odict import ODict, ODictCaseless from netlib.tutils import raises, tresp from .test_message import _test_passthrough_attr, _test_decoded_attr @@ -56,7 +57,7 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] + assert result["cookiename"] == ("cookievalue", CookieAttrs()) def test_get_cookies_with_parameters(self): resp = tresp() @@ -64,13 +65,13 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0][0] == "cookievalue" - attrs = result["cookiename"][0][1] + assert result["cookiename"][0] == "cookievalue" + attrs = result["cookiename"][1] assert len(attrs) == 4 - assert attrs["domain"] == ["example.com"] - assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] - assert attrs["path"] == ["/"] - assert attrs["httponly"] == [None] + assert attrs["domain"] == "example.com" + assert attrs["expires"] == "Wed Oct 21 16:29:41 2015" + assert attrs["path"] == "/" + assert attrs["httponly"] is None def test_get_cookies_no_value(self): resp = tresp() @@ -78,8 +79,8 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0][0] == "" - assert len(result["cookiename"][0][1]) == 2 + assert result["cookiename"][0] == "" + assert len(result["cookiename"][1]) == 2 def test_get_cookies_twocookies(self): resp = tresp() @@ -90,19 +91,16 @@ class TestResponseUtils(object): result = resp.cookies assert len(result) == 2 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] + assert result["cookiename"] == ("cookievalue", CookieAttrs()) assert "othercookie" in result - assert result["othercookie"][0] == ["othervalue", ODict()] + assert result["othercookie"] == ("othervalue", CookieAttrs()) def test_set_cookies(self): resp = tresp() - v = resp.cookies - v.add("foo", ["bar", ODictCaseless()]) - resp.cookies = v + resp.cookies["foo"] = ("bar", {}) - v = resp.cookies - assert len(v) == 1 - assert v["foo"] == [["bar", ODictCaseless()]] + assert len(resp.cookies) == 1 + assert resp.cookies["foo"] == ("bar", CookieAttrs()) def test_refresh(self): r = tresp() From 56b9ec09745a78f1fdb818fca77bc7e9eba01b8b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 18 May 2016 22:50:45 -0700 Subject: [PATCH 4/8] docs++ --- docs/dev/models.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/dev/models.rst b/docs/dev/models.rst index 8c4e68252..f2ddf2421 100644 --- a/docs/dev/models.rst +++ b/docs/dev/models.rst @@ -56,6 +56,17 @@ Datastructures :special-members: :no-undoc-members: + .. autoclass:: MultiDictView + + .. automethod:: get_all + .. automethod:: set_all + .. automethod:: add + .. automethod:: insert + .. automethod:: keys + .. automethod:: values + .. automethod:: items + .. automethod:: to_dict + .. autoclass:: decoded .. automodule:: mitmproxy.models From 560fc756aa87c06c563f3255a4ec5c2002538fd4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 20 May 2016 09:37:13 -0700 Subject: [PATCH 5/8] fix Header docs --- netlib/http/headers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 8959394c8..60d3f429d 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -44,9 +44,9 @@ class Headers(MultiDict): # Headers can also be created from a list of raw (header_name, header_value) byte tuples >>> h = Headers([ - [b"Host",b"example.com"], - [b"Accept",b"text/html"], - [b"accept",b"application/xml"] + (b"Host",b"example.com"), + (b"Accept",b"text/html"), + (b"accept",b"application/xml") ]) # Multiple headers are folded into a single header as per RFC7230 From b538138ead1dc8550f2d4e4a3f30ff70abb95f53 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 20 May 2016 11:04:27 -0700 Subject: [PATCH 6/8] tests++ --- netlib/http/cookies.py | 4 +- netlib/http/message.py | 2 +- netlib/multidict.py | 14 +- test/netlib/http/test_headers.py | 99 +------------- test/netlib/test_multidict.py | 217 +++++++++++++++++++++++++++++++ 5 files changed, 232 insertions(+), 104 deletions(-) create mode 100644 test/netlib/test_multidict.py diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index c5ac45918..88c768706 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -169,8 +169,8 @@ def parse_set_cookie_headers(headers): class CookieAttrs(ImmutableMultiDict): @staticmethod - def _kconv(v): - return v.lower() + def _kconv(key): + return key.lower() @staticmethod def _reduce_values(values): diff --git a/netlib/http/message.py b/netlib/http/message.py index 3c731ea6e..db4054b14 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -279,7 +279,7 @@ class MultiDictView(MultiDict): """ def __init__(self, attr, message): - if False: + if False: # pragma: no cover # We do not want to call the parent constructor here as that # would cause an unnecessary parse/unparse pass. # This is here to silence linters. Message diff --git a/netlib/multidict.py b/netlib/multidict.py index 32d5bfc2e..a359d46b0 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -35,12 +35,20 @@ class MultiDict(MutableMapping, Serializable): @staticmethod @abstractmethod def _reduce_values(values): - pass + """ + If a user accesses multidict["foo"], this method + reduces all values for "foo" to a single value that is returned. + For example, HTTP headers are folded, whereas we will just take + the first cookie we found with that name. + """ @staticmethod @abstractmethod - def _kconv(v): - pass + def _kconv(key): + """ + This method converts a key to its canonical representation. + For example, HTTP headers are case-insensitive, so this method returns key.lower(). + """ def __getitem__(self, key): values = self.get_all(key) diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 48d3b3233..cd2ca9d11 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -41,17 +41,7 @@ class TestHeaders(object): with raises(TypeError): Headers([[b"Host", u"not-bytes"]]) - def test_getitem(self): - headers = Headers(Host="example.com") - assert headers["Host"] == "example.com" - assert headers["host"] == "example.com" - with raises(KeyError): - _ = headers["Accept"] - - headers = self._2host() - assert headers["Host"] == "example.com, example.org" - - def test_str(self): + def test_bytes(self): headers = Headers(Host="example.com") assert bytes(headers) == b"Host: example.com\r\n" @@ -64,93 +54,6 @@ class TestHeaders(object): headers = Headers() assert bytes(headers) == b"" - def test_setitem(self): - headers = Headers() - headers["Host"] = "example.com" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == "example.com" - - headers["host"] = "example.org" - assert "Host" in headers - assert "host" in headers - assert headers["Host"] == "example.org" - - headers["accept"] = "text/plain" - assert len(headers) == 2 - assert "Accept" in headers - assert "Host" in headers - - headers = self._2host() - assert len(headers.fields) == 2 - headers["Host"] = "example.com" - assert len(headers.fields) == 1 - assert "Host" in headers - - def test_delitem(self): - headers = Headers(Host="example.com") - assert len(headers) == 1 - del headers["host"] - assert len(headers) == 0 - try: - del headers["host"] - except KeyError: - assert True - else: - assert False - - headers = self._2host() - del headers["Host"] - assert len(headers) == 0 - - def test_keys(self): - headers = Headers(Host="example.com") - assert list(headers.keys()) == ["Host"] - - headers = self._2host() - assert list(headers.keys()) == ["Host"] - - def test_eq_ne(self): - headers1 = Headers(Host="example.com") - headers2 = Headers(host="example.com") - assert not (headers1 == headers2) - assert headers1 != headers2 - - headers1 = Headers(Host="example.com") - headers2 = Headers(Host="example.com") - assert headers1 == headers2 - assert not (headers1 != headers2) - - assert headers1 != 42 - - def test_get_all(self): - headers = self._2host() - assert headers.get_all("host") == ["example.com", "example.org"] - assert headers.get_all("accept") == [] - - def test_set_all(self): - headers = Headers(Host="example.com") - headers.set_all("Accept", ["text/plain"]) - assert len(headers) == 2 - assert "accept" in headers - - headers = self._2host() - headers.set_all("Host", ["example.org"]) - assert headers["host"] == "example.org" - - headers.set_all("Host", ["example.org", "example.net"]) - assert headers["host"] == "example.org, example.net" - - def test_state(self): - headers = self._2host() - assert len(headers.get_state()) == 2 - assert headers == Headers.from_state(headers.get_state()) - - headers2 = Headers() - assert headers != headers2 - headers2.set_state(headers.get_state()) - assert headers == headers2 - def test_replace_simple(self): headers = Headers(Host="example.com", Accept="text/plain") replacements = headers.replace("Host: ", "X-Host: ") diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py new file mode 100644 index 000000000..ceea38064 --- /dev/null +++ b/test/netlib/test_multidict.py @@ -0,0 +1,217 @@ +from netlib import tutils +from netlib.multidict import MultiDict, ImmutableMultiDict + + +class _TMulti(object): + @staticmethod + def _reduce_values(values): + return values[0] + + @staticmethod + def _kconv(key): + return key.lower() + + +class TMultiDict(_TMulti, MultiDict): + pass + + +class TImmutableMultiDict(_TMulti, ImmutableMultiDict): + pass + + +class TestMultiDict(object): + @staticmethod + def _multi(): + return TMultiDict(( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam") + )) + + def test_init(self): + md = TMultiDict() + assert len(md) == 0 + + md = TMultiDict([("foo", "bar")]) + assert len(md) == 1 + assert md.fields == (("foo", "bar"),) + + def test_repr(self): + assert repr(self._multi()) == ( + "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]" + ) + + def test_getitem(self): + md = TMultiDict([("foo", "bar")]) + assert "foo" in md + assert "Foo" in md + assert md["foo"] == "bar" + + with tutils.raises(KeyError): + _ = md["bar"] + + md_multi = TMultiDict( + [("foo", "a"), ("foo", "b")] + ) + assert md_multi["foo"] == "a" + + def test_setitem(self): + md = TMultiDict() + md["foo"] = "bar" + assert md.fields == (("foo", "bar"),) + + md["foo"] = "baz" + assert md.fields == (("foo", "baz"),) + + md["bar"] = "bam" + assert md.fields == (("foo", "baz"), ("bar", "bam")) + + def test_delitem(self): + md = self._multi() + del md["foo"] + assert "foo" not in md + assert "bar" in md + + with tutils.raises(KeyError): + del md["foo"] + + del md["bar"] + assert md.fields == () + + def test_iter(self): + md = self._multi() + assert list(md.__iter__()) == ["foo", "bar"] + + def test_len(self): + md = TMultiDict() + assert len(md) == 0 + + md = self._multi() + assert len(md) == 2 + + def test_eq(self): + assert TMultiDict() == TMultiDict() + assert not (TMultiDict() == 42) + + md1 = self._multi() + md2 = self._multi() + assert md1 == md2 + md1.fields = md1.fields[1:] + md1.fields[:1] + assert not (md1 == md2) + + def test_ne(self): + assert not TMultiDict() != TMultiDict() + assert TMultiDict() != self._multi() + assert TMultiDict() != 42 + + def test_get_all(self): + md = self._multi() + assert md.get_all("foo") == ["bar"] + assert md.get_all("bar") == ["baz", "bam"] + assert md.get_all("baz") == [] + + def test_set_all(self): + md = TMultiDict() + md.set_all("foo", ["bar", "baz"]) + assert md.fields == (("foo", "bar"), ("foo", "baz")) + + md = TMultiDict(( + ("a", "b"), + ("x", "x"), + ("c", "d"), + ("X", "x"), + ("e", "f"), + )) + md.set_all("x", ["1", "2", "3"]) + assert md.fields == ( + ("a", "b"), + ("x", "1"), + ("c", "d"), + ("x", "2"), + ("e", "f"), + ("x", "3"), + ) + md.set_all("x", ["4"]) + assert md.fields == ( + ("a", "b"), + ("x", "4"), + ("c", "d"), + ("e", "f"), + ) + + def test_add(self): + md = self._multi() + md.add("foo", "foo") + assert md.fields == ( + ("foo", "bar"), + ("bar", "baz"), + ("Bar", "bam"), + ("foo", "foo") + ) + + def test_insert(self): + md = TMultiDict([("b", "b")]) + md.insert(0, "a", "a") + md.insert(2, "c", "c") + assert md.fields == (("a", "a"), ("b", "b"), ("c", "c")) + + def test_keys(self): + md = self._multi() + assert list(md.keys()) == ["foo", "bar"] + assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"] + + def test_values(self): + md = self._multi() + assert list(md.values()) == ["bar", "baz"] + assert list(md.values(multi=True)) == ["bar", "baz", "bam"] + + def test_items(self): + md = self._multi() + assert list(md.items()) == [("foo", "bar"), ("bar", "baz")] + assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")] + + def test_to_dict(self): + md = self._multi() + assert md.to_dict() == { + "foo": "bar", + "bar": ["baz", "bam"] + } + + def test_state(self): + md = self._multi() + assert len(md.get_state()) == 3 + assert md == TMultiDict.from_state(md.get_state()) + + md2 = TMultiDict() + assert md != md2 + md2.set_state(md.get_state()) + assert md == md2 + + +class TestImmutableMultiDict(object): + def test_modify(self): + md = TImmutableMultiDict() + with tutils.raises(TypeError): + md["foo"] = "bar" + + with tutils.raises(TypeError): + del md["foo"] + + with tutils.raises(TypeError): + md.add("foo", "bar") + + def test_with_delitem(self): + md = TImmutableMultiDict([("foo", "bar")]) + assert md.with_delitem("foo").fields == () + assert md.fields == (("foo", "bar"),) + + def test_with_set_all(self): + md = TImmutableMultiDict() + assert md.with_set_all("foo", ["bar"]).fields == (("foo", "bar"),) + assert md.fields == () + + def test_with_insert(self): + md = TImmutableMultiDict() + assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),) + assert md.fields == () \ No newline at end of file From a5c4cd034081d7dcdbd4b46bd69718edb45d4719 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 May 2016 11:37:36 +1200 Subject: [PATCH 7/8] A clearer implementation of MultiDictView This makes MultiDictView work with a simple getter/setter pair, rather than using attributes with implicit leading underscores. Also move MultiDictView into multidict.py and adds some simple unit tests. --- netlib/http/__init__.py | 4 +- netlib/http/message.py | 69 ----------------------------------- netlib/http/request.py | 59 ++++++++++++++++++++---------- netlib/http/response.py | 20 ++++++---- netlib/multidict.py | 49 ++++++++++++++++++++----- test/netlib/test_multidict.py | 26 ++++++++++++- 6 files changed, 119 insertions(+), 108 deletions(-) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9fafa28fc..c4eb1d58d 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -2,13 +2,13 @@ from __future__ import absolute_import, print_function, division from .request import Request from .response import Response from .headers import Headers -from .message import MultiDictView, decoded +from .message import decoded from . import http1, http2, status_codes __all__ = [ "Request", "Response", "Headers", - "MultiDictView", "decoded", + "decoded", "http1", "http2", "status_codes", ] diff --git a/netlib/http/message.py b/netlib/http/message.py index db4054b14..9b0180cf7 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -236,72 +236,3 @@ class decoded(object): def __exit__(self, type, value, tb): if self.ce: self.message.encode(self.ce) - - -class MultiDictView(MultiDict): - """ - Some parts in HTTP (Cookies, URL query strings, ...) require a specific data structure: A MultiDict. - It behaves mostly like an ordered dict but it can have several values for the same key. - - The MultiDictView provides a MultiDict *view* on an :py:class:`Request` or :py:class:`Response`. - That is, it represents a part of the request as a MultiDict, but doesn't contain state/data themselves. - - For example, ``request.cookies`` provides a view on the ``Cookie: ...`` header. - Any change to ``request.cookies`` will also modify the ``Cookie`` header. - Any change to the ``Cookie`` header will also modify ``request.cookies``. - - Example: - - .. code-block:: python - - # Cookies are represented as a MultiDict. - >>> request.cookies - MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] - - # MultiDicts mostly behave like a normal dict. - >>> request.cookies["name"] - "value" - - # If there is more than one value, only the first value is returned. - >>> request.cookies["a"] - "false" - - # `.get_all(key)` returns a list of all values. - >>> request.cookies.get_all("a") - ["false", "42"] - - # Changes to the headers are immediately reflected in the cookies. - >>> request.cookies - MultiDictView[("name", "value"), ...] - >>> del request.headers["Cookie"] - >>> request.cookies - MultiDictView[] # empty now - """ - - def __init__(self, attr, message): - if False: # pragma: no cover - # We do not want to call the parent constructor here as that - # would cause an unnecessary parse/unparse pass. - # This is here to silence linters. Message - super(MultiDictView, self).__init__(None) - self._attr = attr - self._message = message # type: Message - - @staticmethod - def _kconv(key): - # All request-attributes are case-sensitive. - return key - - @staticmethod - def _reduce_values(values): - # We just return the first element if - # multiple elements exist with the same key. - return values[0] - - @property - def fields(self): - return getattr(self._message, "_" + self._attr) - - @fields.setter - def fields(self, value): - setattr(self._message, self._attr, value) diff --git a/netlib/http/request.py b/netlib/http/request.py index ae28084b5..056a2d93a 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -10,8 +10,9 @@ from netlib import utils from netlib.http import cookies from netlib.odict import ODict from .. import encoding +from ..multidict import MultiDictView from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData, MultiDictView +from .message import Message, _native, _always_bytes, MessageData # This regex extracts & splits the host header into host and port. # Handles the edge case of IPv6 addresses containing colons. @@ -228,20 +229,25 @@ class Request(Message): """ The request query string as an :py:class:`MultiDictView` object. """ - return MultiDictView("query", self) + return MultiDictView( + self._get_query, + self._set_query + ) - @property - def _query(self): + def _get_query(self): _, _, _, _, query, _ = urllib.parse.urlparse(self.url) return tuple(utils.urldecode(query)) - @query.setter - def query(self, value): + def _set_query(self, value): query = utils.urlencode(value) scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) _, _, _, self.path = utils.parse_url( urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) + @query.setter + def query(self, value): + self._set_query(value) + @property def cookies(self): # type: () -> MultiDictView @@ -250,16 +256,21 @@ class Request(Message): An empty :py:class:`MultiDictView` object if the cookie monster ate them all. """ - return MultiDictView("cookies", self) + return MultiDictView( + self._get_cookies, + self._set_cookies + ) - @property - def _cookies(self): + def _get_cookies(self): h = self.headers.get_all("Cookie") return tuple(cookies.parse_cookie_headers(h)) + def _set_cookies(self, value): + self.headers["cookie"] = cookies.format_cookie_header(value) + @cookies.setter def cookies(self, value): - self.headers["cookie"] = cookies.format_cookie_header(value) + self._set_cookies(value) @property def path_components(self): @@ -322,17 +333,18 @@ class Request(Message): An empty MultiDictView if the content-type indicates non-form data or the content could not be parsed. """ - return MultiDictView("urlencoded_form", self) + return MultiDictView( + self._get_urlencoded_form, + self._set_urlencoded_form + ) - @property - def _urlencoded_form(self): + def _get_urlencoded_form(self): is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() if is_valid_content_type: return tuple(utils.urldecode(self.content)) return () - @urlencoded_form.setter - def urlencoded_form(self, value): + def _set_urlencoded_form(self, value): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. This will overwrite the existing content if there is one. @@ -340,21 +352,30 @@ class Request(Message): self.headers["content-type"] = "application/x-www-form-urlencoded" self.content = utils.urlencode(value) + @urlencoded_form.setter + def urlencoded_form(self, value): + self._set_urlencoded_form(value) + @property def multipart_form(self): """ The multipart form data as an :py:class:`MultipartFormDict` object. None if the content-type indicates non-form data. """ - return MultiDictView("multipart_form", self) + return MultiDictView( + self._get_multipart_form, + self._set_multipart_form + ) - @property - def _multipart_form(self): + def _get_multipart_form(self): is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() if is_valid_content_type: return utils.multipartdecode(self.headers, self.content) return () + def _set_multipart_form(self, value): + raise NotImplementedError() + @multipart_form.setter def multipart_form(self, value): - raise NotImplementedError() + self._set_multipart_form(value) diff --git a/netlib/http/response.py b/netlib/http/response.py index 6d56fc1f9..7d272e104 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -5,7 +5,8 @@ import time from . import cookies from .headers import Headers -from .message import Message, _native, _always_bytes, MessageData, MultiDictView +from .message import Message, _native, _always_bytes, MessageData +from ..multidict import MultiDictView from .. import utils @@ -80,21 +81,26 @@ class Response(Message): Caveats: Updating the attr """ - return MultiDictView("cookies", self) + return MultiDictView( + self._get_cookies, + self._set_cookies + ) - @property - def _cookies(self): + def _get_cookies(self): h = self.headers.get_all("set-cookie") return tuple(cookies.parse_set_cookie_headers(h)) - @cookies.setter - def cookies(self, all_cookies): + def _set_cookies(self, value): cookie_headers = [] - for k, v in all_cookies: + for k, v in value: header = cookies.format_set_cookie_header(k, v[0], v[1]) cookie_headers.append(header) self.headers.set_all("set-cookie", cookie_headers) + @cookies.setter + def cookies(self, value): + self._set_cookies(value) + def refresh(self, now=None): """ This fairly complex and heuristic function refreshes a server diff --git a/netlib/multidict.py b/netlib/multidict.py index a359d46b0..3af7979b2 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -15,13 +15,7 @@ from .utils import Serializable @six.add_metaclass(ABCMeta) -class MultiDict(MutableMapping, Serializable): - def __init__(self, fields=None): - - # it is important for us that .fields is immutable, so that we can easily - # detect changes to it. - self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] - +class _MultiDict(MutableMapping, Serializable): def __repr__(self): fields = tuple( repr(field) @@ -97,7 +91,7 @@ class MultiDict(MutableMapping, Serializable): value for k, value in self.fields if self._kconv(k) == key - ] + ] def set_all(self, key, values): """ @@ -173,7 +167,7 @@ class MultiDict(MutableMapping, Serializable): if multi: return self.fields else: - return super(MultiDict, self).items() + return super(_MultiDict, self).items() def to_dict(self): """ @@ -213,6 +207,12 @@ class MultiDict(MutableMapping, Serializable): return cls(tuple(x) for x in state) +class MultiDict(_MultiDict): + def __init__(self, fields=None): + super(MultiDict, self).__init__() + self.fields = tuple(fields) if fields else tuple() # type: Tuple[Tuple[bytes, bytes], ...] + + @six.add_metaclass(ABCMeta) class ImmutableMultiDict(MultiDict): def _immutable(self, *_): @@ -246,3 +246,34 @@ class ImmutableMultiDict(MultiDict): ret = self.copy() super(ImmutableMultiDict, ret).insert(index, key, value) return ret + + +class MultiDictView(_MultiDict): + """ + The MultiDictView provides the MultiDict interface over calculated data. + The view itself contains no state - data is retrieved from the parent on + request, and stored back to the parent on change. + """ + def __init__(self, getter, setter): + self._getter = getter + self._setter = setter + super(MultiDictView, self).__init__() + + @staticmethod + def _kconv(key): + # All request-attributes are case-sensitive. + return key + + @staticmethod + def _reduce_values(values): + # We just return the first element if + # multiple elements exist with the same key. + return values[0] + + @property + def fields(self): + return self._getter() + + @fields.setter + def fields(self, value): + return self._setter(value) diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py index ceea38064..5bb65e3fd 100644 --- a/test/netlib/test_multidict.py +++ b/test/netlib/test_multidict.py @@ -1,5 +1,5 @@ from netlib import tutils -from netlib.multidict import MultiDict, ImmutableMultiDict +from netlib.multidict import MultiDict, ImmutableMultiDict, MultiDictView class _TMulti(object): @@ -214,4 +214,26 @@ class TestImmutableMultiDict(object): def test_with_insert(self): md = TImmutableMultiDict() assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),) - assert md.fields == () \ No newline at end of file + + +class TParent(object): + def __init__(self): + self.vals = tuple() + + def setter(self, vals): + self.vals = vals + + def getter(self): + return self.vals + + +class TestMultiDictView(object): + def test_modify(self): + p = TParent() + tv = MultiDictView(p.getter, p.setter) + assert len(tv) == 0 + tv["a"] = "b" + assert p.vals == (("a", "b"),) + tv["c"] = "b" + assert p.vals == (("a", "b"), ("c", "b")) + assert tv["a"] == "b" From 43d796553292a9bbf2c174e31e5e0e39e1068be1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 21 May 2016 15:00:52 +1200 Subject: [PATCH 8/8] Clean un-needed imports --- test/netlib/http/test_request.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index eefdc0917..fae7aefe5 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -3,9 +3,7 @@ from __future__ import absolute_import, print_function, division import six -from netlib import utils from netlib.http import Headers -from netlib.odict import ODict from netlib.tutils import treq, raises from .test_message import _test_decoded_attr, _test_passthrough_attr