This commit is contained in:
Maximilian Hils 2016-05-20 11:04:27 -07:00
parent 560fc756aa
commit b538138ead
5 changed files with 232 additions and 104 deletions

View File

@ -169,8 +169,8 @@ def parse_set_cookie_headers(headers):
class CookieAttrs(ImmutableMultiDict):
@staticmethod
def _kconv(v):
return v.lower()
def _kconv(key):
return key.lower()
@staticmethod
def _reduce_values(values):

View File

@ -279,7 +279,7 @@ class MultiDictView(MultiDict):
"""
def __init__(self, attr, message):
if False:
if False: # pragma: no cover
# We do not want to call the parent constructor here as that
# would cause an unnecessary parse/unparse pass.
# This is here to silence linters. Message

View File

@ -35,12 +35,20 @@ class MultiDict(MutableMapping, Serializable):
@staticmethod
@abstractmethod
def _reduce_values(values):
pass
"""
If a user accesses multidict["foo"], this method
reduces all values for "foo" to a single value that is returned.
For example, HTTP headers are folded, whereas we will just take
the first cookie we found with that name.
"""
@staticmethod
@abstractmethod
def _kconv(v):
pass
def _kconv(key):
"""
This method converts a key to its canonical representation.
For example, HTTP headers are case-insensitive, so this method returns key.lower().
"""
def __getitem__(self, key):
values = self.get_all(key)

View File

@ -41,17 +41,7 @@ class TestHeaders(object):
with raises(TypeError):
Headers([[b"Host", u"not-bytes"]])
def test_getitem(self):
headers = Headers(Host="example.com")
assert headers["Host"] == "example.com"
assert headers["host"] == "example.com"
with raises(KeyError):
_ = headers["Accept"]
headers = self._2host()
assert headers["Host"] == "example.com, example.org"
def test_str(self):
def test_bytes(self):
headers = Headers(Host="example.com")
assert bytes(headers) == b"Host: example.com\r\n"
@ -64,93 +54,6 @@ class TestHeaders(object):
headers = Headers()
assert bytes(headers) == b""
def test_setitem(self):
headers = Headers()
headers["Host"] = "example.com"
assert "Host" in headers
assert "host" in headers
assert headers["Host"] == "example.com"
headers["host"] = "example.org"
assert "Host" in headers
assert "host" in headers
assert headers["Host"] == "example.org"
headers["accept"] = "text/plain"
assert len(headers) == 2
assert "Accept" in headers
assert "Host" in headers
headers = self._2host()
assert len(headers.fields) == 2
headers["Host"] = "example.com"
assert len(headers.fields) == 1
assert "Host" in headers
def test_delitem(self):
headers = Headers(Host="example.com")
assert len(headers) == 1
del headers["host"]
assert len(headers) == 0
try:
del headers["host"]
except KeyError:
assert True
else:
assert False
headers = self._2host()
del headers["Host"]
assert len(headers) == 0
def test_keys(self):
headers = Headers(Host="example.com")
assert list(headers.keys()) == ["Host"]
headers = self._2host()
assert list(headers.keys()) == ["Host"]
def test_eq_ne(self):
headers1 = Headers(Host="example.com")
headers2 = Headers(host="example.com")
assert not (headers1 == headers2)
assert headers1 != headers2
headers1 = Headers(Host="example.com")
headers2 = Headers(Host="example.com")
assert headers1 == headers2
assert not (headers1 != headers2)
assert headers1 != 42
def test_get_all(self):
headers = self._2host()
assert headers.get_all("host") == ["example.com", "example.org"]
assert headers.get_all("accept") == []
def test_set_all(self):
headers = Headers(Host="example.com")
headers.set_all("Accept", ["text/plain"])
assert len(headers) == 2
assert "accept" in headers
headers = self._2host()
headers.set_all("Host", ["example.org"])
assert headers["host"] == "example.org"
headers.set_all("Host", ["example.org", "example.net"])
assert headers["host"] == "example.org, example.net"
def test_state(self):
headers = self._2host()
assert len(headers.get_state()) == 2
assert headers == Headers.from_state(headers.get_state())
headers2 = Headers()
assert headers != headers2
headers2.set_state(headers.get_state())
assert headers == headers2
def test_replace_simple(self):
headers = Headers(Host="example.com", Accept="text/plain")
replacements = headers.replace("Host: ", "X-Host: ")

View File

@ -0,0 +1,217 @@
from netlib import tutils
from netlib.multidict import MultiDict, ImmutableMultiDict
class _TMulti(object):
@staticmethod
def _reduce_values(values):
return values[0]
@staticmethod
def _kconv(key):
return key.lower()
class TMultiDict(_TMulti, MultiDict):
pass
class TImmutableMultiDict(_TMulti, ImmutableMultiDict):
pass
class TestMultiDict(object):
@staticmethod
def _multi():
return TMultiDict((
("foo", "bar"),
("bar", "baz"),
("Bar", "bam")
))
def test_init(self):
md = TMultiDict()
assert len(md) == 0
md = TMultiDict([("foo", "bar")])
assert len(md) == 1
assert md.fields == (("foo", "bar"),)
def test_repr(self):
assert repr(self._multi()) == (
"TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]"
)
def test_getitem(self):
md = TMultiDict([("foo", "bar")])
assert "foo" in md
assert "Foo" in md
assert md["foo"] == "bar"
with tutils.raises(KeyError):
_ = md["bar"]
md_multi = TMultiDict(
[("foo", "a"), ("foo", "b")]
)
assert md_multi["foo"] == "a"
def test_setitem(self):
md = TMultiDict()
md["foo"] = "bar"
assert md.fields == (("foo", "bar"),)
md["foo"] = "baz"
assert md.fields == (("foo", "baz"),)
md["bar"] = "bam"
assert md.fields == (("foo", "baz"), ("bar", "bam"))
def test_delitem(self):
md = self._multi()
del md["foo"]
assert "foo" not in md
assert "bar" in md
with tutils.raises(KeyError):
del md["foo"]
del md["bar"]
assert md.fields == ()
def test_iter(self):
md = self._multi()
assert list(md.__iter__()) == ["foo", "bar"]
def test_len(self):
md = TMultiDict()
assert len(md) == 0
md = self._multi()
assert len(md) == 2
def test_eq(self):
assert TMultiDict() == TMultiDict()
assert not (TMultiDict() == 42)
md1 = self._multi()
md2 = self._multi()
assert md1 == md2
md1.fields = md1.fields[1:] + md1.fields[:1]
assert not (md1 == md2)
def test_ne(self):
assert not TMultiDict() != TMultiDict()
assert TMultiDict() != self._multi()
assert TMultiDict() != 42
def test_get_all(self):
md = self._multi()
assert md.get_all("foo") == ["bar"]
assert md.get_all("bar") == ["baz", "bam"]
assert md.get_all("baz") == []
def test_set_all(self):
md = TMultiDict()
md.set_all("foo", ["bar", "baz"])
assert md.fields == (("foo", "bar"), ("foo", "baz"))
md = TMultiDict((
("a", "b"),
("x", "x"),
("c", "d"),
("X", "x"),
("e", "f"),
))
md.set_all("x", ["1", "2", "3"])
assert md.fields == (
("a", "b"),
("x", "1"),
("c", "d"),
("x", "2"),
("e", "f"),
("x", "3"),
)
md.set_all("x", ["4"])
assert md.fields == (
("a", "b"),
("x", "4"),
("c", "d"),
("e", "f"),
)
def test_add(self):
md = self._multi()
md.add("foo", "foo")
assert md.fields == (
("foo", "bar"),
("bar", "baz"),
("Bar", "bam"),
("foo", "foo")
)
def test_insert(self):
md = TMultiDict([("b", "b")])
md.insert(0, "a", "a")
md.insert(2, "c", "c")
assert md.fields == (("a", "a"), ("b", "b"), ("c", "c"))
def test_keys(self):
md = self._multi()
assert list(md.keys()) == ["foo", "bar"]
assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"]
def test_values(self):
md = self._multi()
assert list(md.values()) == ["bar", "baz"]
assert list(md.values(multi=True)) == ["bar", "baz", "bam"]
def test_items(self):
md = self._multi()
assert list(md.items()) == [("foo", "bar"), ("bar", "baz")]
assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")]
def test_to_dict(self):
md = self._multi()
assert md.to_dict() == {
"foo": "bar",
"bar": ["baz", "bam"]
}
def test_state(self):
md = self._multi()
assert len(md.get_state()) == 3
assert md == TMultiDict.from_state(md.get_state())
md2 = TMultiDict()
assert md != md2
md2.set_state(md.get_state())
assert md == md2
class TestImmutableMultiDict(object):
def test_modify(self):
md = TImmutableMultiDict()
with tutils.raises(TypeError):
md["foo"] = "bar"
with tutils.raises(TypeError):
del md["foo"]
with tutils.raises(TypeError):
md.add("foo", "bar")
def test_with_delitem(self):
md = TImmutableMultiDict([("foo", "bar")])
assert md.with_delitem("foo").fields == ()
assert md.fields == (("foo", "bar"),)
def test_with_set_all(self):
md = TImmutableMultiDict()
assert md.with_set_all("foo", ["bar"]).fields == (("foo", "bar"),)
assert md.fields == ()
def test_with_insert(self):
md = TImmutableMultiDict()
assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),)
assert md.fields == ()