diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 131e8ce5f..b55874ca6 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -14,6 +14,7 @@ if six.PY2: # pragma: no cover return x def _always_bytes(x): + strutils.always_bytes(x, "utf-8", "replace") # raises a TypeError if x != str/bytes/None. return x else: # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. diff --git a/netlib/strutils.py b/netlib/strutils.py index 4cb3b8056..d43c2aab0 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -8,7 +8,10 @@ import six def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes + elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None: + return unicode_or_bytes + else: + raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__)) def native(s, *encoding_opts): diff --git a/pathod/language/http2.py b/pathod/language/http2.py index c0313baa3..519ee6999 100644 --- a/pathod/language/http2.py +++ b/pathod/language/http2.py @@ -189,7 +189,7 @@ class Response(_HTTP2Message): resp = http.Response( b'HTTP/2.0', - self.status_code.string(), + int(self.status_code.string()), b'', headers, body, diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index a2aa91b46..7b162664a 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -6,7 +6,7 @@ import time import hyperframe.frame from hpack.hpack import Encoder, Decoder -from netlib import utils, strutils +from netlib import utils from netlib.http import http2 import netlib.http.headers import netlib.http.response @@ -201,7 +201,7 @@ class HTTP2StateProtocol(object): headers = response.headers.copy() if ':status' not in headers: - headers.insert(0, b':status', strutils.always_bytes(response.status_code)) + headers.insert(0, b':status', str(response.status_code).encode()) if hasattr(response, 'stream_id'): stream_id = response.stream_id diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index ad2bc5487..e8752c52f 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -43,6 +43,15 @@ class TestHeaders(object): with raises(TypeError): Headers([[b"Host", u"not-bytes"]]) + def test_set(self): + headers = Headers() + headers[u"foo"] = u"1" + headers[b"bar"] = b"2" + headers["baz"] = b"3" + with raises(TypeError): + headers["foobar"] = 42 + assert len(headers) == 3 + def test_bytes(self): headers = Headers(Host="example.com") assert bytes(headers) == b"Host: example.com\r\n" diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py index 5be254a3e..0f58cac57 100644 --- a/test/netlib/test_strutils.py +++ b/test/netlib/test_strutils.py @@ -8,6 +8,8 @@ def test_always_bytes(): assert strutils.always_bytes("foo") == b"foo" with tutils.raises(ValueError): strutils.always_bytes(u"\u2605", "ascii") + with tutils.raises(TypeError): + strutils.always_bytes(42, "ascii") def test_native():