python 3++

This commit is contained in:
Maximilian Hils 2015-09-21 00:44:17 +02:00
parent daebd1bd27
commit 73586b1be9
13 changed files with 250 additions and 206 deletions

View File

@ -17,7 +17,7 @@ matrix:
- libssl-dev
- python: 3.5
script:
- py.test3 -n 4 -k "not http2 and not websockets and not wsgi and not models" .
- py.test -n 4 -k "not http2" .
- python: pypy
- python: pypy
env: OPENSSL=1.0.2

View File

@ -8,27 +8,25 @@ import zlib
from .utils import always_byte_args
ENCODINGS = {b"identity", b"gzip", b"deflate"}
ENCODINGS = {"identity", "gzip", "deflate"}
@always_byte_args("ascii", "ignore")
def decode(e, content):
encoding_map = {
b"identity": identity,
b"gzip": decode_gzip,
b"deflate": decode_deflate,
"identity": identity,
"gzip": decode_gzip,
"deflate": decode_deflate,
}
if e not in encoding_map:
return None
return encoding_map[e](content)
@always_byte_args("ascii", "ignore")
def encode(e, content):
encoding_map = {
b"identity": identity,
b"gzip": encode_gzip,
b"deflate": encode_deflate,
"identity": identity,
"gzip": encode_gzip,
"deflate": encode_deflate,
}
if e not in encoding_map:
return None

View File

@ -3,7 +3,7 @@ import copy
from ..odict import ODict
from .. import utils, encoding
from ..utils import always_bytes, always_byte_args
from ..utils import always_bytes, always_byte_args, native
from . import cookies
import six
@ -254,7 +254,7 @@ class Request(Message):
def __repr__(self):
if self.host and self.port:
hostport = "{}:{}".format(self.host, self.port)
hostport = "{}:{}".format(native(self.host,"idna"), self.port)
else:
hostport = ""
path = self.path or ""
@ -279,14 +279,14 @@ class Request(Message):
Modifies this request to remove headers that will compress the
resource's data.
"""
self.headers["Accept-Encoding"] = b"identity"
self.headers["Accept-Encoding"] = "identity"
def constrain_encoding(self):
"""
Limits the permissible Accept-Encoding values, based on what we can
decode appropriately.
"""
accept_encoding = self.headers.get(b"Accept-Encoding")
accept_encoding = native(self.headers.get("Accept-Encoding"), "ascii")
if accept_encoding:
self.headers["Accept-Encoding"] = (
', '.join(
@ -309,9 +309,9 @@ class Request(Message):
indicates non-form data.
"""
if self.body:
if HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower():
if HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower():
return self.get_form_urlencoded()
elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower():
elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower():
return self.get_form_multipart()
return ODict([])
@ -321,12 +321,12 @@ class Request(Message):
Returns an empty ODict if there is no data or the content-type
indicates non-form data.
"""
if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower():
if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower():
return ODict(utils.urldecode(self.body))
return ODict([])
def get_form_multipart(self):
if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower():
if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower():
return ODict(
utils.multipartdecode(
self.headers,
@ -351,7 +351,7 @@ class Request(Message):
Components are unquoted.
"""
_, _, path, _, _, _ = urllib.parse.urlparse(self.url)
return [urllib.parse.unquote(i) for i in path.split(b"/") if i]
return [urllib.parse.unquote(native(i,"ascii")) for i in path.split(b"/") if i]
def set_path_components(self, lst):
"""
@ -360,7 +360,7 @@ class Request(Message):
Components are quoted.
"""
lst = [urllib.parse.quote(i, safe="") for i in lst]
path = b"/" + b"/".join(lst)
path = always_bytes("/" + "/".join(lst))
scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
self.url = urllib.parse.urlunparse(
[scheme, netloc, path, params, query, fragment]
@ -408,11 +408,11 @@ class Request(Message):
def pretty_url(self, hostheader):
if self.form_out == "authority": # upstream proxy mode
return "%s:%s" % (self.pretty_host(hostheader), self.port)
return b"%s:%d" % (always_bytes(self.pretty_host(hostheader)), self.port)
return utils.unparse_url(self.scheme,
self.pretty_host(hostheader),
self.port,
self.path).encode('ascii')
self.path)
def get_cookies(self):
"""
@ -420,7 +420,7 @@ class Request(Message):
"""
ret = ODict()
for i in self.headers.get_all("Cookie"):
ret.extend(cookies.parse_cookie_header(i))
ret.extend(cookies.parse_cookie_header(native(i,"ascii")))
return ret
def set_cookies(self, odict):
@ -441,7 +441,7 @@ class Request(Message):
self.host,
self.port,
self.path
).encode('ascii')
)
@url.setter
def url(self, url):
@ -499,7 +499,7 @@ class Response(Message):
"""
ret = []
for header in self.headers.get_all("Set-Cookie"):
v = cookies.parse_set_cookie_header(header)
v = cookies.parse_set_cookie_header(native(header, "ascii"))
if v:
name, value, attrs = v
ret.append([name, [value, attrs]])

View File

@ -7,7 +7,7 @@ from contextlib import contextmanager
import six
import sys
from . import utils
from . import utils, tcp
from .http import Request, Response, Headers
@ -15,7 +15,6 @@ def treader(bytes):
"""
Construct a tcp.Read object from bytes.
"""
from . import tcp # TODO: move to top once cryptography is on Python 3.5
fp = BytesIO(bytes)
return tcp.Reader(fp)
@ -106,7 +105,7 @@ def treq(**kwargs):
port=22,
path=b"/path",
http_version=b"HTTP/1.1",
headers=Headers(header=b"qvalue"),
headers=Headers(header="qvalue"),
body=b"content"
)
default.update(kwargs)

View File

@ -9,6 +9,41 @@ import six
from six.moves import urllib
def always_bytes(unicode_or_bytes, *encode_args):
if isinstance(unicode_or_bytes, six.text_type):
return unicode_or_bytes.encode(*encode_args)
return unicode_or_bytes
def always_byte_args(*encode_args):
"""Decorator that transparently encodes all arguments passed as unicode"""
def decorator(fun):
def _fun(*args, **kwargs):
args = [always_bytes(arg, *encode_args) for arg in args]
kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)}
return fun(*args, **kwargs)
return _fun
return decorator
def native(s, encoding="latin-1"):
"""
Convert :py:class:`bytes` or :py:class:`unicode` to the native
:py:class:`str` type, using latin1 encoding if conversion is necessary.
https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types
"""
if not isinstance(s, (six.binary_type, six.text_type)):
raise TypeError("%r is neither bytes nor unicode" % s)
if six.PY3:
if isinstance(s, six.binary_type):
return s.decode(encoding)
else:
if isinstance(s, six.text_type):
return s.encode(encoding)
return s
def isascii(bytes):
try:
bytes.decode("ascii")
@ -238,6 +273,7 @@ def get_header_tokens(headers, key):
return [token.strip() for token in tokens]
@always_byte_args()
def hostport(scheme, host, port):
"""
Returns the host component, with a port specifcation if needed.
@ -323,20 +359,3 @@ def multipartdecode(headers, content):
r.append((key, value))
return r
return []
def always_bytes(unicode_or_bytes, *encode_args):
if isinstance(unicode_or_bytes, six.text_type):
return unicode_or_bytes.encode(*encode_args)
return unicode_or_bytes
def always_byte_args(*encode_args):
"""Decorator that transparently encodes all arguments passed as unicode"""
def decorator(fun):
def _fun(*args, **kwargs):
args = [always_bytes(arg, *encode_args) for arg in args]
kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)}
return fun(*args, **kwargs)
return _fun
return decorator

View File

@ -2,13 +2,14 @@ from __future__ import absolute_import
import os
import struct
import io
import warnings
import six
from .protocol import Masker
from netlib import tcp
from netlib import utils
DEFAULT = object()
MAX_16_BIT_INT = (1 << 16)
MAX_64_BIT_INT = (1 << 64)
@ -33,9 +34,9 @@ class FrameHeader(object):
rsv1=False,
rsv2=False,
rsv3=False,
masking_key=DEFAULT,
mask=DEFAULT,
length_code=DEFAULT
masking_key=None,
mask=None,
length_code=None
):
if not 0 <= opcode < 2 ** 4:
raise ValueError("opcode must be 0-16")
@ -46,18 +47,18 @@ class FrameHeader(object):
self.rsv2 = rsv2
self.rsv3 = rsv3
if length_code is DEFAULT:
if length_code is None:
self.length_code = self._make_length_code(self.payload_length)
else:
self.length_code = length_code
if mask is DEFAULT and masking_key is DEFAULT:
if mask is None and masking_key is None:
self.mask = False
self.masking_key = ""
elif mask is DEFAULT:
self.masking_key = b""
elif mask is None:
self.mask = 1
self.masking_key = masking_key
elif masking_key is DEFAULT:
elif masking_key is None:
self.mask = mask
self.masking_key = os.urandom(4)
else:
@ -81,7 +82,7 @@ class FrameHeader(object):
else:
return 127
def human_readable(self):
def __repr__(self):
vals = [
"ws frame:",
OPCODE.get_name(self.opcode, hex(self.opcode)).lower()
@ -98,7 +99,11 @@ class FrameHeader(object):
vals.append(" %s" % utils.pretty_size(self.payload_length))
return "".join(vals)
def to_bytes(self):
def human_readable(self):
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
return repr(self)
def __bytes__(self):
first_byte = utils.setbit(0, 7, self.fin)
first_byte = utils.setbit(first_byte, 6, self.rsv1)
first_byte = utils.setbit(first_byte, 5, self.rsv2)
@ -107,7 +112,7 @@ class FrameHeader(object):
second_byte = utils.setbit(self.length_code, 7, self.mask)
b = chr(first_byte) + chr(second_byte)
b = six.int2byte(first_byte) + six.int2byte(second_byte)
if self.payload_length < 126:
pass
@ -119,10 +124,17 @@ class FrameHeader(object):
# '!Q' = pack as 64 bit unsigned long long
# add 8 bytes extended payload length
b += struct.pack('!Q', self.payload_length)
if self.masking_key is not None:
if self.masking_key:
b += self.masking_key
return b
if six.PY2:
__str__ = __bytes__
def to_bytes(self):
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
return bytes(self)
@classmethod
def from_file(cls, fp):
"""
@ -154,7 +166,7 @@ class FrameHeader(object):
if mask_bit == 1:
masking_key = fp.safe_read(4)
else:
masking_key = None
masking_key = False
return cls(
fin=fin,
@ -169,7 +181,9 @@ class FrameHeader(object):
)
def __eq__(self, other):
return self.to_bytes() == other.to_bytes()
if isinstance(other, FrameHeader):
return bytes(self) == bytes(other)
return False
class Frame(object):
@ -200,7 +214,7 @@ class Frame(object):
+---------------------------------------------------------------+
"""
def __init__(self, payload="", **kwargs):
def __init__(self, payload=b"", **kwargs):
self.payload = payload
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
self.header = FrameHeader(**kwargs)
@ -216,7 +230,7 @@ class Frame(object):
masking_key = os.urandom(4)
else:
mask_bit = 0
masking_key = None
masking_key = False
return cls(
message,
@ -234,28 +248,37 @@ class Frame(object):
"""
return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))
def human_readable(self):
ret = self.header.human_readable()
def __repr__(self):
ret = repr(self.header)
if self.payload:
ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload)
ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii")
return ret
def __repr__(self):
return self.header.human_readable()
def human_readable(self):
warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
return repr(self)
def to_bytes(self):
def __bytes__(self):
"""
Serialize the frame to wire format. Returns a string.
"""
b = self.header.to_bytes()
b = bytes(self.header)
if self.header.masking_key:
b += Masker(self.header.masking_key)(self.payload)
else:
b += self.payload
return b
if six.PY2:
__str__ = __bytes__
def to_bytes(self):
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
return bytes(self)
def to_file(self, writer):
writer.write(self.to_bytes())
warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
writer.write(bytes(self))
writer.flush()
@classmethod
@ -286,4 +309,6 @@ class Frame(object):
)
def __eq__(self, other):
return self.to_bytes() == other.to_bytes()
if isinstance(other, Frame):
return bytes(self) == bytes(other)
return False

View File

@ -17,11 +17,12 @@ from __future__ import absolute_import
import base64
import hashlib
import os
import binascii
import six
from ..http import Headers
from .. import utils
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
VERSION = "13"
HEADER_WEBSOCKET_KEY = 'sec-websocket-key'
@ -41,14 +42,21 @@ class Masker(object):
def __init__(self, key):
self.key = key
self.masks = [six.byte2int(byte) for byte in key]
self.offset = 0
def mask(self, offset, data):
result = ""
for c in data:
result += chr(ord(c) ^ self.masks[offset % 4])
offset += 1
result = bytearray(data)
if six.PY2:
for i in range(len(data)):
result[i] ^= ord(self.key[offset % 4])
offset += 1
result = str(result)
else:
for i in range(len(data)):
result[i] ^= self.key[offset % 4]
offset += 1
result = bytes(result)
return result
def __call__(self, data):
@ -73,37 +81,35 @@ class WebsocketsProtocol(object):
"""
if not key:
key = base64.b64encode(os.urandom(16)).decode('utf-8')
return Headers([
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
(HEADER_WEBSOCKET_KEY, key),
(HEADER_WEBSOCKET_VERSION, version)
])
return Headers(**{
HEADER_WEBSOCKET_KEY: key,
HEADER_WEBSOCKET_VERSION: version,
"Connection": "Upgrade",
"Upgrade": "websocket",
})
@classmethod
def server_handshake_headers(self, key):
"""
The server response is a valid HTTP 101 response.
"""
return Headers(
[
('Connection', 'Upgrade'),
('Upgrade', 'websocket'),
(HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key))
]
)
return Headers(**{
HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key),
"Connection": "Upgrade",
"Upgrade": "websocket",
})
@classmethod
def check_client_handshake(self, headers):
if headers.get("upgrade") != "websocket":
if headers.get("upgrade") != b"websocket":
return
return headers.get(HEADER_WEBSOCKET_KEY)
@classmethod
def check_server_handshake(self, headers):
if headers.get("upgrade") != "websocket":
if headers.get("upgrade") != b"websocket":
return
return headers.get(HEADER_WEBSOCKET_ACCEPT)
@ -111,5 +117,5 @@ class WebsocketsProtocol(object):
@classmethod
def create_server_nonce(self, client_nonce):
return base64.b64encode(
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest())
)

View File

@ -1,14 +1,15 @@
from __future__ import (absolute_import, print_function, division)
from io import BytesIO
from io import BytesIO, StringIO
import urllib
import time
import traceback
import six
from six.moves import urllib
from netlib.utils import always_bytes, native
from . import http, tcp
class ClientConn(object):
def __init__(self, address):
@ -24,9 +25,10 @@ class Flow(object):
class Request(object):
def __init__(self, scheme, method, path, headers, body):
def __init__(self, scheme, method, path, http_version, headers, body):
self.scheme, self.method, self.path = scheme, method, path
self.headers, self.body = headers, body
self.http_version = http_version
def date_time_string():
@ -53,38 +55,38 @@ class WSGIAdaptor(object):
self.app, self.domain, self.port, self.sversion = app, domain, port, sversion
def make_environ(self, flow, errsoc, **extra):
if '?' in flow.request.path:
path_info, query = flow.request.path.split('?', 1)
path = native(flow.request.path)
if '?' in path:
path_info, query = native(path).split('?', 1)
else:
path_info = flow.request.path
path_info = path
query = ''
environ = {
'wsgi.version': (1, 0),
'wsgi.url_scheme': flow.request.scheme,
'wsgi.url_scheme': native(flow.request.scheme),
'wsgi.input': BytesIO(flow.request.body or b""),
'wsgi.errors': errsoc,
'wsgi.multithread': True,
'wsgi.multiprocess': False,
'wsgi.run_once': False,
'SERVER_SOFTWARE': self.sversion,
'REQUEST_METHOD': flow.request.method,
'REQUEST_METHOD': native(flow.request.method),
'SCRIPT_NAME': '',
'PATH_INFO': urllib.unquote(path_info),
'PATH_INFO': urllib.parse.unquote(path_info),
'QUERY_STRING': query,
'CONTENT_TYPE': flow.request.headers.get('Content-Type', ''),
'CONTENT_LENGTH': flow.request.headers.get('Content-Length', ''),
'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', '')),
'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', '')),
'SERVER_NAME': self.domain,
'SERVER_PORT': str(self.port),
# FIXME: We need to pick up the protocol read from the request.
'SERVER_PROTOCOL': "HTTP/1.1",
'SERVER_PROTOCOL': native(flow.request.http_version),
}
environ.update(extra)
if flow.client_conn.address:
environ["REMOTE_ADDR"], environ[
"REMOTE_PORT"] = flow.client_conn.address()
environ["REMOTE_ADDR"] = native(flow.client_conn.address.host)
environ["REMOTE_PORT"] = flow.client_conn.address.port
for key, value in flow.request.headers.items():
key = 'HTTP_' + key.upper().replace('-', '_')
key = 'HTTP_' + native(key).upper().replace('-', '_')
if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
environ[key] = value
return environ
@ -99,7 +101,7 @@ class WSGIAdaptor(object):
<h1>Internal Server Error</h1>
<pre>%s"</pre>
</html>
""".strip() % s
""".strip() % s.encode()
if not headers_sent:
soc.write(b"HTTP/1.1 500 Internal Server Error\r\n")
soc.write(b"Content-Type: text/html\r\n")
@ -117,7 +119,7 @@ class WSGIAdaptor(object):
def write(data):
if not state["headers_sent"]:
soc.write(b"HTTP/1.1 %s\r\n" % state["status"])
soc.write(b"HTTP/1.1 %s\r\n" % state["status"].encode())
headers = state["headers"]
if 'server' not in headers:
headers["Server"] = self.sversion
@ -132,18 +134,17 @@ class WSGIAdaptor(object):
def start_response(status, headers, exc_info=None):
if exc_info:
try:
if state["headers_sent"]:
six.reraise(*exc_info)
finally:
exc_info = None
if state["headers_sent"]:
six.reraise(*exc_info)
elif state["status"]:
raise AssertionError('Response already started')
state["status"] = status
state["headers"] = http.Headers(headers)
return write
state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k,v in headers])
if exc_info:
self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2]))
state["headers_sent"] = True
errs = BytesIO()
errs = six.BytesIO()
try:
dataiter = self.app(
self.make_environ(request, errs, **env), start_response
@ -155,7 +156,7 @@ class WSGIAdaptor(object):
except Exception as e:
try:
s = traceback.format_exc()
errs.write(s)
errs.write(s.encode("utf-8", "replace"))
self.error_page(soc, state["headers_sent"], s)
except Exception: # pragma: no cover
pass

View File

@ -58,20 +58,20 @@ class TestRequest(object):
req = tutils.treq()
req.headers["Accept-Encoding"] = "foobar"
req.anticomp()
assert req.headers["Accept-Encoding"] == "identity"
assert req.headers["Accept-Encoding"] == b"identity"
def test_constrain_encoding(self):
req = tutils.treq()
req.headers["Accept-Encoding"] = "identity, gzip, foo"
req.constrain_encoding()
assert "foo" not in req.headers["Accept-Encoding"]
assert b"foo" not in req.headers["Accept-Encoding"]
def test_update_host(self):
req = tutils.treq()
req.headers["Host"] = ""
req.host = "foobar"
req.update_host_header()
assert req.headers["Host"] == "foobar"
assert req.headers["Host"] == b"foobar"
def test_get_form(self):
req = tutils.treq()
@ -132,7 +132,7 @@ class TestRequest(object):
def test_set_path_components(self):
req = tutils.treq()
req.set_path_components(["foo", "bar"])
req.set_path_components([b"foo", b"bar"])
# TODO: add meaningful assertions
def test_get_query(self):
@ -140,7 +140,7 @@ class TestRequest(object):
assert req.get_query().lst == []
req.url = "http://localhost:80/foo?bar=42"
assert req.get_query().lst == [("bar", "42")]
assert req.get_query().lst == [(b"bar", b"42")]
def test_set_query(self):
req = tutils.treq()
@ -167,12 +167,12 @@ class TestRequest(object):
def test_pretty_url(self):
req = tutils.treq()
req.form_out = "authority"
assert req.pretty_url(True) == "address:22"
assert req.pretty_url(False) == "address:22"
assert req.pretty_url(True) == b"address:22"
assert req.pretty_url(False) == b"address:22"
req.form_out = "relative"
assert req.pretty_url(True) == "http://address:22/path"
assert req.pretty_url(False) == "http://address:22/path"
assert req.pretty_url(True) == b"http://address:22/path"
assert req.pretty_url(False) == b"http://address:22/path"
def test_get_cookies_none(self):
headers = Headers()
@ -213,11 +213,11 @@ class TestRequest(object):
def test_set_url(self):
r = tutils.treq(form_in="absolute")
r.url = "https://otheraddress:42/ORLY"
assert r.scheme == "https"
assert r.host == "otheraddress"
r.url = b"https://otheraddress:42/ORLY"
assert r.scheme == b"https"
assert r.host == b"otheraddress"
assert r.port == 42
assert r.path == "/ORLY"
assert r.path == b"/ORLY"
try:
r.url = "//localhost:80/foo@bar"
@ -374,8 +374,8 @@ class TestResponse(object):
def test_get_cookies_twocookies(self):
resp = tutils.tresp()
resp.headers = Headers([
["Set-Cookie", "cookiename=cookievalue"],
["Set-Cookie", "othercookie=othervalue"]
[b"Set-Cookie", b"cookiename=cookievalue"],
[b"Set-Cookie", b"othercookie=othervalue"]
])
result = resp.get_cookies()
assert len(result) == 2
@ -399,8 +399,8 @@ class TestHeaders(object):
def _2host(self):
return Headers(
[
["Host", "example.com"],
["host", "example.org"]
[b"Host", b"example.com"],
[b"host", b"example.org"]
]
)
@ -408,37 +408,37 @@ class TestHeaders(object):
headers = Headers()
assert len(headers) == 0
headers = Headers([["Host", "example.com"]])
headers = Headers([[b"Host", b"example.com"]])
assert len(headers) == 1
assert headers["Host"] == "example.com"
assert headers["Host"] == b"example.com"
headers = Headers(Host="example.com")
assert len(headers) == 1
assert headers["Host"] == "example.com"
assert headers["Host"] == b"example.com"
headers = Headers(
[["Host", "invalid"]],
[[b"Host", b"invalid"]],
Host="example.com"
)
assert len(headers) == 1
assert headers["Host"] == "example.com"
assert headers["Host"] == b"example.com"
headers = Headers(
[["Host", "invalid"], ["Accept", "text/plain"]],
[[b"Host", b"invalid"], [b"Accept", b"text/plain"]],
Host="example.com"
)
assert len(headers) == 2
assert headers["Host"] == "example.com"
assert headers["Accept"] == "text/plain"
assert headers["Host"] == b"example.com"
assert headers["Accept"] == b"text/plain"
def test_getitem(self):
headers = Headers(Host="example.com")
assert headers["Host"] == "example.com"
assert headers["host"] == "example.com"
assert headers["Host"] == b"example.com"
assert headers["host"] == b"example.com"
tutils.raises(KeyError, headers.__getitem__, "Accept")
headers = self._2host()
assert headers["Host"] == "example.com, example.org"
assert headers["Host"] == b"example.com, example.org"
def test_str(self):
headers = Headers(Host="example.com")
@ -458,12 +458,12 @@ class TestHeaders(object):
headers["Host"] = "example.com"
assert "Host" in headers
assert "host" in headers
assert headers["Host"] == "example.com"
assert headers["Host"] == b"example.com"
headers["host"] = "example.org"
assert "Host" in headers
assert "host" in headers
assert headers["Host"] == "example.org"
assert headers["Host"] == b"example.org"
headers["accept"] = "text/plain"
assert len(headers) == 2
@ -494,12 +494,10 @@ class TestHeaders(object):
def test_keys(self):
headers = Headers(Host="example.com")
assert len(headers.keys()) == 1
assert headers.keys()[0] == "Host"
assert list(headers.keys()) == [b"Host"]
headers = self._2host()
assert len(headers.keys()) == 1
assert headers.keys()[0] == "Host"
assert list(headers.keys()) == [b"Host"]
def test_eq_ne(self):
headers1 = Headers(Host="example.com")
@ -516,7 +514,7 @@ class TestHeaders(object):
def test_get_all(self):
headers = self._2host()
assert headers.get_all("host") == ["example.com", "example.org"]
assert headers.get_all("host") == [b"example.com", b"example.org"]
assert headers.get_all("accept") == []
def test_set_all(self):
@ -527,10 +525,10 @@ class TestHeaders(object):
headers = self._2host()
headers.set_all("Host", ["example.org"])
assert headers["host"] == "example.org"
assert headers["host"] == b"example.org"
headers.set_all("Host", ["example.org", "example.net"])
assert headers["host"] == "example.org, example.net"
assert headers["host"] == b"example.org, example.net"
def test_state(self):
headers = self._2host()

View File

@ -4,8 +4,6 @@ from netlib import encoding
def test_identity():
assert b"string" == encoding.decode("identity", b"string")
assert b"string" == encoding.encode("identity", b"string")
assert b"string" == encoding.encode(b"identity", b"string")
assert b"string" == encoding.decode(b"identity", b"string")
assert not encoding.encode("nonexistent", b"string")
assert not encoding.decode("nonexistent encoding", b"string")

View File

@ -5,8 +5,8 @@ from netlib.http import Headers
def tflow():
headers = Headers(test="value")
req = wsgi.Request("http", "GET", "/", headers, "")
headers = Headers(test=b"value")
req = wsgi.Request("http", "GET", "/", "HTTP/1.1", headers, "")
return wsgi.Flow(("127.0.0.1", 8888), req)
@ -20,7 +20,7 @@ class TestApp:
status = '200 OK'
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers)
return ['Hello', ' world!\n']
return [b'Hello', b' world!\n']
class TestWSGI:
@ -47,8 +47,8 @@ class TestWSGI:
assert not err
val = wfile.getvalue()
assert "Hello world" in val
assert "Server:" in val
assert b"Hello world" in val
assert b"Server:" in val
def _serve(self, app):
w = wsgi.WSGIAdaptor(app, "foo", 80, "version")
@ -77,7 +77,7 @@ class TestWSGI:
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers)
start_response(status, response_headers)
assert "Internal Server Error" in self._serve(app)
assert b"Internal Server Error" in self._serve(app)
def test_serve_single_err(self):
def app(environ, start_response):
@ -88,7 +88,8 @@ class TestWSGI:
status = '200 OK'
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers, ei)
assert "Internal Server Error" in self._serve(app)
yield b""
assert b"Internal Server Error" in self._serve(app)
def test_serve_double_err(self):
def app(environ, start_response):
@ -99,7 +100,7 @@ class TestWSGI:
status = '200 OK'
response_headers = [('Content-type', 'text/plain')]
start_response(status, response_headers)
yield "aaa"
yield b"aaa"
start_response(status, response_headers, ei)
yield "bbb"
assert "Internal Server Error" in self._serve(app)
yield b"bbb"
assert b"Internal Server Error" in self._serve(app)

View File

@ -41,7 +41,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
key = self.protocol.check_client_handshake(req.headers)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble + "\r\n")
self.wfile.write(preamble.encode() + b"\r\n")
headers = self.protocol.server_handshake_headers(key)
self.wfile.write(str(headers) + "\r\n")
self.wfile.flush()
@ -62,11 +62,11 @@ class WebSocketsClient(tcp.TCPClient):
def connect(self):
super(WebSocketsClient, self).connect()
preamble = 'GET / HTTP/1.1'
self.wfile.write(preamble + "\r\n")
preamble = b'GET / HTTP/1.1'
self.wfile.write(preamble + b"\r\n")
headers = self.protocol.client_handshake_headers()
self.client_nonce = headers["sec-websocket-key"]
self.wfile.write(str(headers) + "\r\n")
self.wfile.write(bytes(headers) + b"\r\n")
self.wfile.flush()
resp = read_response(self.rfile, treq(method="GET"))
@ -101,7 +101,7 @@ class TestWebSockets(tservers.ServerTestBase):
assert response == msg
def test_simple_echo(self):
self.echo("hello I'm the client")
self.echo(b"hello I'm the client")
def test_frame_sizes(self):
# length can fit in the the 7 bit payload length
@ -161,10 +161,10 @@ class BadHandshakeHandler(WebSocketsEchoHandler):
client_hs = read_request(self.rfile)
self.protocol.check_client_handshake(client_hs.headers)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble + "\r\n")
headers = self.protocol.server_handshake_headers("malformed key")
self.wfile.write(str(headers) + "\r\n")
preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble.encode())
headers = self.protocol.server_handshake_headers(b"malformed key")
self.wfile.write(bytes(headers) + b"\r\n")
self.wfile.flush()
self.handshake_done = True
@ -180,7 +180,7 @@ class TestBadHandshake(tservers.ServerTestBase):
def test(self):
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()
client.send_message("hello")
client.send_message(b"hello")
class TestFrameHeader:
@ -188,8 +188,7 @@ class TestFrameHeader:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.FrameHeader(*args, **kwargs)
bytes = f.to_bytes()
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f)))
assert f == f2
round()
round(fin=1)
@ -201,11 +200,11 @@ class TestFrameHeader:
round(payload_length=1000)
round(payload_length=10000)
round(opcode=websockets.OPCODE.PING)
round(masking_key="test")
round(masking_key=b"test")
def test_human_readable(self):
f = websockets.FrameHeader(
masking_key="test",
masking_key=b"test",
fin=True,
payload_length=10
)
@ -214,23 +213,23 @@ class TestFrameHeader:
assert f.human_readable()
def test_funky(self):
f = websockets.FrameHeader(masking_key="test", mask=False)
f = websockets.FrameHeader(masking_key=b"test", mask=False)
bytes = f.to_bytes()
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
assert not f2.mask
def test_violations(self):
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
tutils.raises("masking key", websockets.FrameHeader, masking_key="x")
tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
def test_automask(self):
f = websockets.FrameHeader(mask=True)
assert f.masking_key
f = websockets.FrameHeader(masking_key="foob")
f = websockets.FrameHeader(masking_key=b"foob")
assert f.mask
f = websockets.FrameHeader(masking_key="foob", mask=0)
f = websockets.FrameHeader(masking_key=b"foob", mask=0)
assert not f.mask
assert f.masking_key
@ -240,31 +239,31 @@ class TestFrame:
def test_roundtrip(self):
def round(*args, **kwargs):
f = websockets.Frame(*args, **kwargs)
bytes = f.to_bytes()
f2 = websockets.Frame.from_file(tutils.treader(bytes))
raw = bytes(f)
f2 = websockets.Frame.from_file(tutils.treader(raw))
assert f == f2
round("test")
round("test", fin=1)
round("test", rsv1=1)
round("test", opcode=websockets.OPCODE.PING)
round("test", masking_key="test")
round(b"test")
round(b"test", fin=1)
round(b"test", rsv1=1)
round(b"test", opcode=websockets.OPCODE.PING)
round(b"test", masking_key=b"test")
def test_human_readable(self):
f = websockets.Frame()
assert f.human_readable()
assert repr(f)
def test_masker():
tests = [
["a"],
["four"],
["fourf"],
["fourfive"],
["a", "aasdfasdfa", "asdf"],
["a" * 50, "aasdfasdfa", "asdf"],
[b"a"],
[b"four"],
[b"fourf"],
[b"fourfive"],
[b"a", b"aasdfasdfa", b"asdf"],
[b"a" * 50, b"aasdfasdfa", b"asdf"],
]
for i in tests:
m = websockets.Masker("abcd")
data = "".join([m(t) for t in i])
data2 = websockets.Masker("abcd")(data)
assert data2 == "".join(i)
m = websockets.Masker(b"abcd")
data = b"".join([m(t) for t in i])
data2 = websockets.Masker(b"abcd")(data)
assert data2 == b"".join(i)