raise TypeError on invalid header assignment, fix #1562

This commit is contained in:
Maximilian Hils 2016-09-21 19:21:32 -07:00
parent 1e5a5b03f8
commit 770936f1f9
4 changed files with 16 additions and 1 deletions

View File

@ -14,6 +14,7 @@ if six.PY2: # pragma: no cover
return x return x
def _always_bytes(x): def _always_bytes(x):
strutils.always_bytes(x, "utf-8", "replace") # raises a TypeError if x != str/bytes/None.
return x return x
else: else:
# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.

View File

@ -8,7 +8,10 @@ import six
def always_bytes(unicode_or_bytes, *encode_args): def always_bytes(unicode_or_bytes, *encode_args):
if isinstance(unicode_or_bytes, six.text_type): if isinstance(unicode_or_bytes, six.text_type):
return unicode_or_bytes.encode(*encode_args) return unicode_or_bytes.encode(*encode_args)
return unicode_or_bytes elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None:
return unicode_or_bytes
else:
raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__))
def native(s, *encoding_opts): def native(s, *encoding_opts):

View File

@ -43,6 +43,15 @@ class TestHeaders(object):
with raises(TypeError): with raises(TypeError):
Headers([[b"Host", u"not-bytes"]]) Headers([[b"Host", u"not-bytes"]])
def test_set(self):
headers = Headers()
headers[u"foo"] = u"1"
headers[b"bar"] = b"2"
headers["baz"] = b"3"
with raises(TypeError):
headers["foobar"] = 42
assert len(headers) == 3
def test_bytes(self): def test_bytes(self):
headers = Headers(Host="example.com") headers = Headers(Host="example.com")
assert bytes(headers) == b"Host: example.com\r\n" assert bytes(headers) == b"Host: example.com\r\n"

View File

@ -8,6 +8,8 @@ def test_always_bytes():
assert strutils.always_bytes("foo") == b"foo" assert strutils.always_bytes("foo") == b"foo"
with tutils.raises(ValueError): with tutils.raises(ValueError):
strutils.always_bytes(u"\u2605", "ascii") strutils.always_bytes(u"\u2605", "ascii")
with tutils.raises(TypeError):
strutils.always_bytes(42, "ascii")
def test_native(): def test_native():