mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 23:09:44 +00:00
python 3++
This commit is contained in:
parent
daebd1bd27
commit
73586b1be9
@ -17,7 +17,7 @@ matrix:
|
||||
- libssl-dev
|
||||
- python: 3.5
|
||||
script:
|
||||
- py.test3 -n 4 -k "not http2 and not websockets and not wsgi and not models" .
|
||||
- py.test -n 4 -k "not http2" .
|
||||
- python: pypy
|
||||
- python: pypy
|
||||
env: OPENSSL=1.0.2
|
||||
|
@ -8,27 +8,25 @@ import zlib
|
||||
from .utils import always_byte_args
|
||||
|
||||
|
||||
ENCODINGS = {b"identity", b"gzip", b"deflate"}
|
||||
ENCODINGS = {"identity", "gzip", "deflate"}
|
||||
|
||||
|
||||
@always_byte_args("ascii", "ignore")
|
||||
def decode(e, content):
|
||||
encoding_map = {
|
||||
b"identity": identity,
|
||||
b"gzip": decode_gzip,
|
||||
b"deflate": decode_deflate,
|
||||
"identity": identity,
|
||||
"gzip": decode_gzip,
|
||||
"deflate": decode_deflate,
|
||||
}
|
||||
if e not in encoding_map:
|
||||
return None
|
||||
return encoding_map[e](content)
|
||||
|
||||
|
||||
@always_byte_args("ascii", "ignore")
|
||||
def encode(e, content):
|
||||
encoding_map = {
|
||||
b"identity": identity,
|
||||
b"gzip": encode_gzip,
|
||||
b"deflate": encode_deflate,
|
||||
"identity": identity,
|
||||
"gzip": encode_gzip,
|
||||
"deflate": encode_deflate,
|
||||
}
|
||||
if e not in encoding_map:
|
||||
return None
|
||||
|
@ -3,7 +3,7 @@ import copy
|
||||
|
||||
from ..odict import ODict
|
||||
from .. import utils, encoding
|
||||
from ..utils import always_bytes, always_byte_args
|
||||
from ..utils import always_bytes, always_byte_args, native
|
||||
from . import cookies
|
||||
|
||||
import six
|
||||
@ -254,7 +254,7 @@ class Request(Message):
|
||||
|
||||
def __repr__(self):
|
||||
if self.host and self.port:
|
||||
hostport = "{}:{}".format(self.host, self.port)
|
||||
hostport = "{}:{}".format(native(self.host,"idna"), self.port)
|
||||
else:
|
||||
hostport = ""
|
||||
path = self.path or ""
|
||||
@ -279,14 +279,14 @@ class Request(Message):
|
||||
Modifies this request to remove headers that will compress the
|
||||
resource's data.
|
||||
"""
|
||||
self.headers["Accept-Encoding"] = b"identity"
|
||||
self.headers["Accept-Encoding"] = "identity"
|
||||
|
||||
def constrain_encoding(self):
|
||||
"""
|
||||
Limits the permissible Accept-Encoding values, based on what we can
|
||||
decode appropriately.
|
||||
"""
|
||||
accept_encoding = self.headers.get(b"Accept-Encoding")
|
||||
accept_encoding = native(self.headers.get("Accept-Encoding"), "ascii")
|
||||
if accept_encoding:
|
||||
self.headers["Accept-Encoding"] = (
|
||||
', '.join(
|
||||
@ -309,9 +309,9 @@ class Request(Message):
|
||||
indicates non-form data.
|
||||
"""
|
||||
if self.body:
|
||||
if HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower():
|
||||
if HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower():
|
||||
return self.get_form_urlencoded()
|
||||
elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower():
|
||||
elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower():
|
||||
return self.get_form_multipart()
|
||||
return ODict([])
|
||||
|
||||
@ -321,12 +321,12 @@ class Request(Message):
|
||||
Returns an empty ODict if there is no data or the content-type
|
||||
indicates non-form data.
|
||||
"""
|
||||
if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower():
|
||||
if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower():
|
||||
return ODict(utils.urldecode(self.body))
|
||||
return ODict([])
|
||||
|
||||
def get_form_multipart(self):
|
||||
if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower():
|
||||
if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower():
|
||||
return ODict(
|
||||
utils.multipartdecode(
|
||||
self.headers,
|
||||
@ -351,7 +351,7 @@ class Request(Message):
|
||||
Components are unquoted.
|
||||
"""
|
||||
_, _, path, _, _, _ = urllib.parse.urlparse(self.url)
|
||||
return [urllib.parse.unquote(i) for i in path.split(b"/") if i]
|
||||
return [urllib.parse.unquote(native(i,"ascii")) for i in path.split(b"/") if i]
|
||||
|
||||
def set_path_components(self, lst):
|
||||
"""
|
||||
@ -360,7 +360,7 @@ class Request(Message):
|
||||
Components are quoted.
|
||||
"""
|
||||
lst = [urllib.parse.quote(i, safe="") for i in lst]
|
||||
path = b"/" + b"/".join(lst)
|
||||
path = always_bytes("/" + "/".join(lst))
|
||||
scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
|
||||
self.url = urllib.parse.urlunparse(
|
||||
[scheme, netloc, path, params, query, fragment]
|
||||
@ -408,11 +408,11 @@ class Request(Message):
|
||||
|
||||
def pretty_url(self, hostheader):
|
||||
if self.form_out == "authority": # upstream proxy mode
|
||||
return "%s:%s" % (self.pretty_host(hostheader), self.port)
|
||||
return b"%s:%d" % (always_bytes(self.pretty_host(hostheader)), self.port)
|
||||
return utils.unparse_url(self.scheme,
|
||||
self.pretty_host(hostheader),
|
||||
self.port,
|
||||
self.path).encode('ascii')
|
||||
self.path)
|
||||
|
||||
def get_cookies(self):
|
||||
"""
|
||||
@ -420,7 +420,7 @@ class Request(Message):
|
||||
"""
|
||||
ret = ODict()
|
||||
for i in self.headers.get_all("Cookie"):
|
||||
ret.extend(cookies.parse_cookie_header(i))
|
||||
ret.extend(cookies.parse_cookie_header(native(i,"ascii")))
|
||||
return ret
|
||||
|
||||
def set_cookies(self, odict):
|
||||
@ -441,7 +441,7 @@ class Request(Message):
|
||||
self.host,
|
||||
self.port,
|
||||
self.path
|
||||
).encode('ascii')
|
||||
)
|
||||
|
||||
@url.setter
|
||||
def url(self, url):
|
||||
@ -499,7 +499,7 @@ class Response(Message):
|
||||
"""
|
||||
ret = []
|
||||
for header in self.headers.get_all("Set-Cookie"):
|
||||
v = cookies.parse_set_cookie_header(header)
|
||||
v = cookies.parse_set_cookie_header(native(header, "ascii"))
|
||||
if v:
|
||||
name, value, attrs = v
|
||||
ret.append([name, [value, attrs]])
|
||||
|
@ -7,7 +7,7 @@ from contextlib import contextmanager
|
||||
import six
|
||||
import sys
|
||||
|
||||
from . import utils
|
||||
from . import utils, tcp
|
||||
from .http import Request, Response, Headers
|
||||
|
||||
|
||||
@ -15,7 +15,6 @@ def treader(bytes):
|
||||
"""
|
||||
Construct a tcp.Read object from bytes.
|
||||
"""
|
||||
from . import tcp # TODO: move to top once cryptography is on Python 3.5
|
||||
fp = BytesIO(bytes)
|
||||
return tcp.Reader(fp)
|
||||
|
||||
@ -106,7 +105,7 @@ def treq(**kwargs):
|
||||
port=22,
|
||||
path=b"/path",
|
||||
http_version=b"HTTP/1.1",
|
||||
headers=Headers(header=b"qvalue"),
|
||||
headers=Headers(header="qvalue"),
|
||||
body=b"content"
|
||||
)
|
||||
default.update(kwargs)
|
||||
|
@ -9,6 +9,41 @@ import six
|
||||
from six.moves import urllib
|
||||
|
||||
|
||||
def always_bytes(unicode_or_bytes, *encode_args):
|
||||
if isinstance(unicode_or_bytes, six.text_type):
|
||||
return unicode_or_bytes.encode(*encode_args)
|
||||
return unicode_or_bytes
|
||||
|
||||
|
||||
def always_byte_args(*encode_args):
|
||||
"""Decorator that transparently encodes all arguments passed as unicode"""
|
||||
def decorator(fun):
|
||||
def _fun(*args, **kwargs):
|
||||
args = [always_bytes(arg, *encode_args) for arg in args]
|
||||
kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)}
|
||||
return fun(*args, **kwargs)
|
||||
return _fun
|
||||
return decorator
|
||||
|
||||
|
||||
def native(s, encoding="latin-1"):
|
||||
"""
|
||||
Convert :py:class:`bytes` or :py:class:`unicode` to the native
|
||||
:py:class:`str` type, using latin1 encoding if conversion is necessary.
|
||||
|
||||
https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types
|
||||
"""
|
||||
if not isinstance(s, (six.binary_type, six.text_type)):
|
||||
raise TypeError("%r is neither bytes nor unicode" % s)
|
||||
if six.PY3:
|
||||
if isinstance(s, six.binary_type):
|
||||
return s.decode(encoding)
|
||||
else:
|
||||
if isinstance(s, six.text_type):
|
||||
return s.encode(encoding)
|
||||
return s
|
||||
|
||||
|
||||
def isascii(bytes):
|
||||
try:
|
||||
bytes.decode("ascii")
|
||||
@ -238,6 +273,7 @@ def get_header_tokens(headers, key):
|
||||
return [token.strip() for token in tokens]
|
||||
|
||||
|
||||
@always_byte_args()
|
||||
def hostport(scheme, host, port):
|
||||
"""
|
||||
Returns the host component, with a port specifcation if needed.
|
||||
@ -323,20 +359,3 @@ def multipartdecode(headers, content):
|
||||
r.append((key, value))
|
||||
return r
|
||||
return []
|
||||
|
||||
|
||||
def always_bytes(unicode_or_bytes, *encode_args):
|
||||
if isinstance(unicode_or_bytes, six.text_type):
|
||||
return unicode_or_bytes.encode(*encode_args)
|
||||
return unicode_or_bytes
|
||||
|
||||
|
||||
def always_byte_args(*encode_args):
|
||||
"""Decorator that transparently encodes all arguments passed as unicode"""
|
||||
def decorator(fun):
|
||||
def _fun(*args, **kwargs):
|
||||
args = [always_bytes(arg, *encode_args) for arg in args]
|
||||
kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)}
|
||||
return fun(*args, **kwargs)
|
||||
return _fun
|
||||
return decorator
|
||||
|
@ -2,13 +2,14 @@ from __future__ import absolute_import
|
||||
import os
|
||||
import struct
|
||||
import io
|
||||
import warnings
|
||||
|
||||
import six
|
||||
|
||||
from .protocol import Masker
|
||||
from netlib import tcp
|
||||
from netlib import utils
|
||||
|
||||
DEFAULT = object()
|
||||
|
||||
MAX_16_BIT_INT = (1 << 16)
|
||||
MAX_64_BIT_INT = (1 << 64)
|
||||
@ -33,9 +34,9 @@ class FrameHeader(object):
|
||||
rsv1=False,
|
||||
rsv2=False,
|
||||
rsv3=False,
|
||||
masking_key=DEFAULT,
|
||||
mask=DEFAULT,
|
||||
length_code=DEFAULT
|
||||
masking_key=None,
|
||||
mask=None,
|
||||
length_code=None
|
||||
):
|
||||
if not 0 <= opcode < 2 ** 4:
|
||||
raise ValueError("opcode must be 0-16")
|
||||
@ -46,18 +47,18 @@ class FrameHeader(object):
|
||||
self.rsv2 = rsv2
|
||||
self.rsv3 = rsv3
|
||||
|
||||
if length_code is DEFAULT:
|
||||
if length_code is None:
|
||||
self.length_code = self._make_length_code(self.payload_length)
|
||||
else:
|
||||
self.length_code = length_code
|
||||
|
||||
if mask is DEFAULT and masking_key is DEFAULT:
|
||||
if mask is None and masking_key is None:
|
||||
self.mask = False
|
||||
self.masking_key = ""
|
||||
elif mask is DEFAULT:
|
||||
self.masking_key = b""
|
||||
elif mask is None:
|
||||
self.mask = 1
|
||||
self.masking_key = masking_key
|
||||
elif masking_key is DEFAULT:
|
||||
elif masking_key is None:
|
||||
self.mask = mask
|
||||
self.masking_key = os.urandom(4)
|
||||
else:
|
||||
@ -81,7 +82,7 @@ class FrameHeader(object):
|
||||
else:
|
||||
return 127
|
||||
|
||||
def human_readable(self):
|
||||
def __repr__(self):
|
||||
vals = [
|
||||
"ws frame:",
|
||||
OPCODE.get_name(self.opcode, hex(self.opcode)).lower()
|
||||
@ -98,7 +99,11 @@ class FrameHeader(object):
|
||||
vals.append(" %s" % utils.pretty_size(self.payload_length))
|
||||
return "".join(vals)
|
||||
|
||||
def to_bytes(self):
|
||||
def human_readable(self):
|
||||
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
|
||||
return repr(self)
|
||||
|
||||
def __bytes__(self):
|
||||
first_byte = utils.setbit(0, 7, self.fin)
|
||||
first_byte = utils.setbit(first_byte, 6, self.rsv1)
|
||||
first_byte = utils.setbit(first_byte, 5, self.rsv2)
|
||||
@ -107,7 +112,7 @@ class FrameHeader(object):
|
||||
|
||||
second_byte = utils.setbit(self.length_code, 7, self.mask)
|
||||
|
||||
b = chr(first_byte) + chr(second_byte)
|
||||
b = six.int2byte(first_byte) + six.int2byte(second_byte)
|
||||
|
||||
if self.payload_length < 126:
|
||||
pass
|
||||
@ -119,10 +124,17 @@ class FrameHeader(object):
|
||||
# '!Q' = pack as 64 bit unsigned long long
|
||||
# add 8 bytes extended payload length
|
||||
b += struct.pack('!Q', self.payload_length)
|
||||
if self.masking_key is not None:
|
||||
if self.masking_key:
|
||||
b += self.masking_key
|
||||
return b
|
||||
|
||||
if six.PY2:
|
||||
__str__ = __bytes__
|
||||
|
||||
def to_bytes(self):
|
||||
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
|
||||
return bytes(self)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, fp):
|
||||
"""
|
||||
@ -154,7 +166,7 @@ class FrameHeader(object):
|
||||
if mask_bit == 1:
|
||||
masking_key = fp.safe_read(4)
|
||||
else:
|
||||
masking_key = None
|
||||
masking_key = False
|
||||
|
||||
return cls(
|
||||
fin=fin,
|
||||
@ -169,7 +181,9 @@ class FrameHeader(object):
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.to_bytes() == other.to_bytes()
|
||||
if isinstance(other, FrameHeader):
|
||||
return bytes(self) == bytes(other)
|
||||
return False
|
||||
|
||||
|
||||
class Frame(object):
|
||||
@ -200,7 +214,7 @@ class Frame(object):
|
||||
+---------------------------------------------------------------+
|
||||
"""
|
||||
|
||||
def __init__(self, payload="", **kwargs):
|
||||
def __init__(self, payload=b"", **kwargs):
|
||||
self.payload = payload
|
||||
kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
|
||||
self.header = FrameHeader(**kwargs)
|
||||
@ -216,7 +230,7 @@ class Frame(object):
|
||||
masking_key = os.urandom(4)
|
||||
else:
|
||||
mask_bit = 0
|
||||
masking_key = None
|
||||
masking_key = False
|
||||
|
||||
return cls(
|
||||
message,
|
||||
@ -234,28 +248,37 @@ class Frame(object):
|
||||
"""
|
||||
return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))
|
||||
|
||||
def human_readable(self):
|
||||
ret = self.header.human_readable()
|
||||
def __repr__(self):
|
||||
ret = repr(self.header)
|
||||
if self.payload:
|
||||
ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload)
|
||||
ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii")
|
||||
return ret
|
||||
|
||||
def __repr__(self):
|
||||
return self.header.human_readable()
|
||||
def human_readable(self):
|
||||
warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning)
|
||||
return repr(self)
|
||||
|
||||
def to_bytes(self):
|
||||
def __bytes__(self):
|
||||
"""
|
||||
Serialize the frame to wire format. Returns a string.
|
||||
"""
|
||||
b = self.header.to_bytes()
|
||||
b = bytes(self.header)
|
||||
if self.header.masking_key:
|
||||
b += Masker(self.header.masking_key)(self.payload)
|
||||
else:
|
||||
b += self.payload
|
||||
return b
|
||||
|
||||
if six.PY2:
|
||||
__str__ = __bytes__
|
||||
|
||||
def to_bytes(self):
|
||||
warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning)
|
||||
return bytes(self)
|
||||
|
||||
def to_file(self, writer):
|
||||
writer.write(self.to_bytes())
|
||||
warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning)
|
||||
writer.write(bytes(self))
|
||||
writer.flush()
|
||||
|
||||
@classmethod
|
||||
@ -286,4 +309,6 @@ class Frame(object):
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.to_bytes() == other.to_bytes()
|
||||
if isinstance(other, Frame):
|
||||
return bytes(self) == bytes(other)
|
||||
return False
|
@ -17,11 +17,12 @@ from __future__ import absolute_import
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import binascii
|
||||
import six
|
||||
from ..http import Headers
|
||||
from .. import utils
|
||||
|
||||
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||
websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||
VERSION = "13"
|
||||
|
||||
HEADER_WEBSOCKET_KEY = 'sec-websocket-key'
|
||||
@ -41,14 +42,21 @@ class Masker(object):
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.masks = [six.byte2int(byte) for byte in key]
|
||||
self.offset = 0
|
||||
|
||||
def mask(self, offset, data):
|
||||
result = ""
|
||||
for c in data:
|
||||
result += chr(ord(c) ^ self.masks[offset % 4])
|
||||
offset += 1
|
||||
result = bytearray(data)
|
||||
if six.PY2:
|
||||
for i in range(len(data)):
|
||||
result[i] ^= ord(self.key[offset % 4])
|
||||
offset += 1
|
||||
result = str(result)
|
||||
else:
|
||||
|
||||
for i in range(len(data)):
|
||||
result[i] ^= self.key[offset % 4]
|
||||
offset += 1
|
||||
result = bytes(result)
|
||||
return result
|
||||
|
||||
def __call__(self, data):
|
||||
@ -73,37 +81,35 @@ class WebsocketsProtocol(object):
|
||||
"""
|
||||
if not key:
|
||||
key = base64.b64encode(os.urandom(16)).decode('utf-8')
|
||||
return Headers([
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
(HEADER_WEBSOCKET_KEY, key),
|
||||
(HEADER_WEBSOCKET_VERSION, version)
|
||||
])
|
||||
return Headers(**{
|
||||
HEADER_WEBSOCKET_KEY: key,
|
||||
HEADER_WEBSOCKET_VERSION: version,
|
||||
"Connection": "Upgrade",
|
||||
"Upgrade": "websocket",
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def server_handshake_headers(self, key):
|
||||
"""
|
||||
The server response is a valid HTTP 101 response.
|
||||
"""
|
||||
return Headers(
|
||||
[
|
||||
('Connection', 'Upgrade'),
|
||||
('Upgrade', 'websocket'),
|
||||
(HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key))
|
||||
]
|
||||
)
|
||||
return Headers(**{
|
||||
HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key),
|
||||
"Connection": "Upgrade",
|
||||
"Upgrade": "websocket",
|
||||
})
|
||||
|
||||
|
||||
@classmethod
|
||||
def check_client_handshake(self, headers):
|
||||
if headers.get("upgrade") != "websocket":
|
||||
if headers.get("upgrade") != b"websocket":
|
||||
return
|
||||
return headers.get(HEADER_WEBSOCKET_KEY)
|
||||
|
||||
|
||||
@classmethod
|
||||
def check_server_handshake(self, headers):
|
||||
if headers.get("upgrade") != "websocket":
|
||||
if headers.get("upgrade") != b"websocket":
|
||||
return
|
||||
return headers.get(HEADER_WEBSOCKET_ACCEPT)
|
||||
|
||||
@ -111,5 +117,5 @@ class WebsocketsProtocol(object):
|
||||
@classmethod
|
||||
def create_server_nonce(self, client_nonce):
|
||||
return base64.b64encode(
|
||||
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
|
||||
binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest())
|
||||
)
|
||||
|
@ -1,14 +1,15 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
from io import BytesIO
|
||||
from io import BytesIO, StringIO
|
||||
import urllib
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import six
|
||||
from six.moves import urllib
|
||||
|
||||
from netlib.utils import always_bytes, native
|
||||
from . import http, tcp
|
||||
|
||||
|
||||
class ClientConn(object):
|
||||
|
||||
def __init__(self, address):
|
||||
@ -24,9 +25,10 @@ class Flow(object):
|
||||
|
||||
class Request(object):
|
||||
|
||||
def __init__(self, scheme, method, path, headers, body):
|
||||
def __init__(self, scheme, method, path, http_version, headers, body):
|
||||
self.scheme, self.method, self.path = scheme, method, path
|
||||
self.headers, self.body = headers, body
|
||||
self.http_version = http_version
|
||||
|
||||
|
||||
def date_time_string():
|
||||
@ -53,38 +55,38 @@ class WSGIAdaptor(object):
|
||||
self.app, self.domain, self.port, self.sversion = app, domain, port, sversion
|
||||
|
||||
def make_environ(self, flow, errsoc, **extra):
|
||||
if '?' in flow.request.path:
|
||||
path_info, query = flow.request.path.split('?', 1)
|
||||
path = native(flow.request.path)
|
||||
if '?' in path:
|
||||
path_info, query = native(path).split('?', 1)
|
||||
else:
|
||||
path_info = flow.request.path
|
||||
path_info = path
|
||||
query = ''
|
||||
environ = {
|
||||
'wsgi.version': (1, 0),
|
||||
'wsgi.url_scheme': flow.request.scheme,
|
||||
'wsgi.url_scheme': native(flow.request.scheme),
|
||||
'wsgi.input': BytesIO(flow.request.body or b""),
|
||||
'wsgi.errors': errsoc,
|
||||
'wsgi.multithread': True,
|
||||
'wsgi.multiprocess': False,
|
||||
'wsgi.run_once': False,
|
||||
'SERVER_SOFTWARE': self.sversion,
|
||||
'REQUEST_METHOD': flow.request.method,
|
||||
'REQUEST_METHOD': native(flow.request.method),
|
||||
'SCRIPT_NAME': '',
|
||||
'PATH_INFO': urllib.unquote(path_info),
|
||||
'PATH_INFO': urllib.parse.unquote(path_info),
|
||||
'QUERY_STRING': query,
|
||||
'CONTENT_TYPE': flow.request.headers.get('Content-Type', ''),
|
||||
'CONTENT_LENGTH': flow.request.headers.get('Content-Length', ''),
|
||||
'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', '')),
|
||||
'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', '')),
|
||||
'SERVER_NAME': self.domain,
|
||||
'SERVER_PORT': str(self.port),
|
||||
# FIXME: We need to pick up the protocol read from the request.
|
||||
'SERVER_PROTOCOL': "HTTP/1.1",
|
||||
'SERVER_PROTOCOL': native(flow.request.http_version),
|
||||
}
|
||||
environ.update(extra)
|
||||
if flow.client_conn.address:
|
||||
environ["REMOTE_ADDR"], environ[
|
||||
"REMOTE_PORT"] = flow.client_conn.address()
|
||||
environ["REMOTE_ADDR"] = native(flow.client_conn.address.host)
|
||||
environ["REMOTE_PORT"] = flow.client_conn.address.port
|
||||
|
||||
for key, value in flow.request.headers.items():
|
||||
key = 'HTTP_' + key.upper().replace('-', '_')
|
||||
key = 'HTTP_' + native(key).upper().replace('-', '_')
|
||||
if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
|
||||
environ[key] = value
|
||||
return environ
|
||||
@ -99,7 +101,7 @@ class WSGIAdaptor(object):
|
||||
<h1>Internal Server Error</h1>
|
||||
<pre>%s"</pre>
|
||||
</html>
|
||||
""".strip() % s
|
||||
""".strip() % s.encode()
|
||||
if not headers_sent:
|
||||
soc.write(b"HTTP/1.1 500 Internal Server Error\r\n")
|
||||
soc.write(b"Content-Type: text/html\r\n")
|
||||
@ -117,7 +119,7 @@ class WSGIAdaptor(object):
|
||||
|
||||
def write(data):
|
||||
if not state["headers_sent"]:
|
||||
soc.write(b"HTTP/1.1 %s\r\n" % state["status"])
|
||||
soc.write(b"HTTP/1.1 %s\r\n" % state["status"].encode())
|
||||
headers = state["headers"]
|
||||
if 'server' not in headers:
|
||||
headers["Server"] = self.sversion
|
||||
@ -132,18 +134,17 @@ class WSGIAdaptor(object):
|
||||
|
||||
def start_response(status, headers, exc_info=None):
|
||||
if exc_info:
|
||||
try:
|
||||
if state["headers_sent"]:
|
||||
six.reraise(*exc_info)
|
||||
finally:
|
||||
exc_info = None
|
||||
if state["headers_sent"]:
|
||||
six.reraise(*exc_info)
|
||||
elif state["status"]:
|
||||
raise AssertionError('Response already started')
|
||||
state["status"] = status
|
||||
state["headers"] = http.Headers(headers)
|
||||
return write
|
||||
state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k,v in headers])
|
||||
if exc_info:
|
||||
self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2]))
|
||||
state["headers_sent"] = True
|
||||
|
||||
errs = BytesIO()
|
||||
errs = six.BytesIO()
|
||||
try:
|
||||
dataiter = self.app(
|
||||
self.make_environ(request, errs, **env), start_response
|
||||
@ -155,7 +156,7 @@ class WSGIAdaptor(object):
|
||||
except Exception as e:
|
||||
try:
|
||||
s = traceback.format_exc()
|
||||
errs.write(s)
|
||||
errs.write(s.encode("utf-8", "replace"))
|
||||
self.error_page(soc, state["headers_sent"], s)
|
||||
except Exception: # pragma: no cover
|
||||
pass
|
||||
|
@ -58,20 +58,20 @@ class TestRequest(object):
|
||||
req = tutils.treq()
|
||||
req.headers["Accept-Encoding"] = "foobar"
|
||||
req.anticomp()
|
||||
assert req.headers["Accept-Encoding"] == "identity"
|
||||
assert req.headers["Accept-Encoding"] == b"identity"
|
||||
|
||||
def test_constrain_encoding(self):
|
||||
req = tutils.treq()
|
||||
req.headers["Accept-Encoding"] = "identity, gzip, foo"
|
||||
req.constrain_encoding()
|
||||
assert "foo" not in req.headers["Accept-Encoding"]
|
||||
assert b"foo" not in req.headers["Accept-Encoding"]
|
||||
|
||||
def test_update_host(self):
|
||||
req = tutils.treq()
|
||||
req.headers["Host"] = ""
|
||||
req.host = "foobar"
|
||||
req.update_host_header()
|
||||
assert req.headers["Host"] == "foobar"
|
||||
assert req.headers["Host"] == b"foobar"
|
||||
|
||||
def test_get_form(self):
|
||||
req = tutils.treq()
|
||||
@ -132,7 +132,7 @@ class TestRequest(object):
|
||||
|
||||
def test_set_path_components(self):
|
||||
req = tutils.treq()
|
||||
req.set_path_components(["foo", "bar"])
|
||||
req.set_path_components([b"foo", b"bar"])
|
||||
# TODO: add meaningful assertions
|
||||
|
||||
def test_get_query(self):
|
||||
@ -140,7 +140,7 @@ class TestRequest(object):
|
||||
assert req.get_query().lst == []
|
||||
|
||||
req.url = "http://localhost:80/foo?bar=42"
|
||||
assert req.get_query().lst == [("bar", "42")]
|
||||
assert req.get_query().lst == [(b"bar", b"42")]
|
||||
|
||||
def test_set_query(self):
|
||||
req = tutils.treq()
|
||||
@ -167,12 +167,12 @@ class TestRequest(object):
|
||||
def test_pretty_url(self):
|
||||
req = tutils.treq()
|
||||
req.form_out = "authority"
|
||||
assert req.pretty_url(True) == "address:22"
|
||||
assert req.pretty_url(False) == "address:22"
|
||||
assert req.pretty_url(True) == b"address:22"
|
||||
assert req.pretty_url(False) == b"address:22"
|
||||
|
||||
req.form_out = "relative"
|
||||
assert req.pretty_url(True) == "http://address:22/path"
|
||||
assert req.pretty_url(False) == "http://address:22/path"
|
||||
assert req.pretty_url(True) == b"http://address:22/path"
|
||||
assert req.pretty_url(False) == b"http://address:22/path"
|
||||
|
||||
def test_get_cookies_none(self):
|
||||
headers = Headers()
|
||||
@ -213,11 +213,11 @@ class TestRequest(object):
|
||||
|
||||
def test_set_url(self):
|
||||
r = tutils.treq(form_in="absolute")
|
||||
r.url = "https://otheraddress:42/ORLY"
|
||||
assert r.scheme == "https"
|
||||
assert r.host == "otheraddress"
|
||||
r.url = b"https://otheraddress:42/ORLY"
|
||||
assert r.scheme == b"https"
|
||||
assert r.host == b"otheraddress"
|
||||
assert r.port == 42
|
||||
assert r.path == "/ORLY"
|
||||
assert r.path == b"/ORLY"
|
||||
|
||||
try:
|
||||
r.url = "//localhost:80/foo@bar"
|
||||
@ -374,8 +374,8 @@ class TestResponse(object):
|
||||
def test_get_cookies_twocookies(self):
|
||||
resp = tutils.tresp()
|
||||
resp.headers = Headers([
|
||||
["Set-Cookie", "cookiename=cookievalue"],
|
||||
["Set-Cookie", "othercookie=othervalue"]
|
||||
[b"Set-Cookie", b"cookiename=cookievalue"],
|
||||
[b"Set-Cookie", b"othercookie=othervalue"]
|
||||
])
|
||||
result = resp.get_cookies()
|
||||
assert len(result) == 2
|
||||
@ -399,8 +399,8 @@ class TestHeaders(object):
|
||||
def _2host(self):
|
||||
return Headers(
|
||||
[
|
||||
["Host", "example.com"],
|
||||
["host", "example.org"]
|
||||
[b"Host", b"example.com"],
|
||||
[b"host", b"example.org"]
|
||||
]
|
||||
)
|
||||
|
||||
@ -408,37 +408,37 @@ class TestHeaders(object):
|
||||
headers = Headers()
|
||||
assert len(headers) == 0
|
||||
|
||||
headers = Headers([["Host", "example.com"]])
|
||||
headers = Headers([[b"Host", b"example.com"]])
|
||||
assert len(headers) == 1
|
||||
assert headers["Host"] == "example.com"
|
||||
assert headers["Host"] == b"example.com"
|
||||
|
||||
headers = Headers(Host="example.com")
|
||||
assert len(headers) == 1
|
||||
assert headers["Host"] == "example.com"
|
||||
assert headers["Host"] == b"example.com"
|
||||
|
||||
headers = Headers(
|
||||
[["Host", "invalid"]],
|
||||
[[b"Host", b"invalid"]],
|
||||
Host="example.com"
|
||||
)
|
||||
assert len(headers) == 1
|
||||
assert headers["Host"] == "example.com"
|
||||
assert headers["Host"] == b"example.com"
|
||||
|
||||
headers = Headers(
|
||||
[["Host", "invalid"], ["Accept", "text/plain"]],
|
||||
[[b"Host", b"invalid"], [b"Accept", b"text/plain"]],
|
||||
Host="example.com"
|
||||
)
|
||||
assert len(headers) == 2
|
||||
assert headers["Host"] == "example.com"
|
||||
assert headers["Accept"] == "text/plain"
|
||||
assert headers["Host"] == b"example.com"
|
||||
assert headers["Accept"] == b"text/plain"
|
||||
|
||||
def test_getitem(self):
|
||||
headers = Headers(Host="example.com")
|
||||
assert headers["Host"] == "example.com"
|
||||
assert headers["host"] == "example.com"
|
||||
assert headers["Host"] == b"example.com"
|
||||
assert headers["host"] == b"example.com"
|
||||
tutils.raises(KeyError, headers.__getitem__, "Accept")
|
||||
|
||||
headers = self._2host()
|
||||
assert headers["Host"] == "example.com, example.org"
|
||||
assert headers["Host"] == b"example.com, example.org"
|
||||
|
||||
def test_str(self):
|
||||
headers = Headers(Host="example.com")
|
||||
@ -458,12 +458,12 @@ class TestHeaders(object):
|
||||
headers["Host"] = "example.com"
|
||||
assert "Host" in headers
|
||||
assert "host" in headers
|
||||
assert headers["Host"] == "example.com"
|
||||
assert headers["Host"] == b"example.com"
|
||||
|
||||
headers["host"] = "example.org"
|
||||
assert "Host" in headers
|
||||
assert "host" in headers
|
||||
assert headers["Host"] == "example.org"
|
||||
assert headers["Host"] == b"example.org"
|
||||
|
||||
headers["accept"] = "text/plain"
|
||||
assert len(headers) == 2
|
||||
@ -494,12 +494,10 @@ class TestHeaders(object):
|
||||
|
||||
def test_keys(self):
|
||||
headers = Headers(Host="example.com")
|
||||
assert len(headers.keys()) == 1
|
||||
assert headers.keys()[0] == "Host"
|
||||
assert list(headers.keys()) == [b"Host"]
|
||||
|
||||
headers = self._2host()
|
||||
assert len(headers.keys()) == 1
|
||||
assert headers.keys()[0] == "Host"
|
||||
assert list(headers.keys()) == [b"Host"]
|
||||
|
||||
def test_eq_ne(self):
|
||||
headers1 = Headers(Host="example.com")
|
||||
@ -516,7 +514,7 @@ class TestHeaders(object):
|
||||
|
||||
def test_get_all(self):
|
||||
headers = self._2host()
|
||||
assert headers.get_all("host") == ["example.com", "example.org"]
|
||||
assert headers.get_all("host") == [b"example.com", b"example.org"]
|
||||
assert headers.get_all("accept") == []
|
||||
|
||||
def test_set_all(self):
|
||||
@ -527,10 +525,10 @@ class TestHeaders(object):
|
||||
|
||||
headers = self._2host()
|
||||
headers.set_all("Host", ["example.org"])
|
||||
assert headers["host"] == "example.org"
|
||||
assert headers["host"] == b"example.org"
|
||||
|
||||
headers.set_all("Host", ["example.org", "example.net"])
|
||||
assert headers["host"] == "example.org, example.net"
|
||||
assert headers["host"] == b"example.org, example.net"
|
||||
|
||||
def test_state(self):
|
||||
headers = self._2host()
|
||||
|
@ -4,8 +4,6 @@ from netlib import encoding
|
||||
def test_identity():
|
||||
assert b"string" == encoding.decode("identity", b"string")
|
||||
assert b"string" == encoding.encode("identity", b"string")
|
||||
assert b"string" == encoding.encode(b"identity", b"string")
|
||||
assert b"string" == encoding.decode(b"identity", b"string")
|
||||
assert not encoding.encode("nonexistent", b"string")
|
||||
assert not encoding.decode("nonexistent encoding", b"string")
|
||||
|
||||
|
@ -5,8 +5,8 @@ from netlib.http import Headers
|
||||
|
||||
|
||||
def tflow():
|
||||
headers = Headers(test="value")
|
||||
req = wsgi.Request("http", "GET", "/", headers, "")
|
||||
headers = Headers(test=b"value")
|
||||
req = wsgi.Request("http", "GET", "/", "HTTP/1.1", headers, "")
|
||||
return wsgi.Flow(("127.0.0.1", 8888), req)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ class TestApp:
|
||||
status = '200 OK'
|
||||
response_headers = [('Content-type', 'text/plain')]
|
||||
start_response(status, response_headers)
|
||||
return ['Hello', ' world!\n']
|
||||
return [b'Hello', b' world!\n']
|
||||
|
||||
|
||||
class TestWSGI:
|
||||
@ -47,8 +47,8 @@ class TestWSGI:
|
||||
assert not err
|
||||
|
||||
val = wfile.getvalue()
|
||||
assert "Hello world" in val
|
||||
assert "Server:" in val
|
||||
assert b"Hello world" in val
|
||||
assert b"Server:" in val
|
||||
|
||||
def _serve(self, app):
|
||||
w = wsgi.WSGIAdaptor(app, "foo", 80, "version")
|
||||
@ -77,7 +77,7 @@ class TestWSGI:
|
||||
response_headers = [('Content-type', 'text/plain')]
|
||||
start_response(status, response_headers)
|
||||
start_response(status, response_headers)
|
||||
assert "Internal Server Error" in self._serve(app)
|
||||
assert b"Internal Server Error" in self._serve(app)
|
||||
|
||||
def test_serve_single_err(self):
|
||||
def app(environ, start_response):
|
||||
@ -88,7 +88,8 @@ class TestWSGI:
|
||||
status = '200 OK'
|
||||
response_headers = [('Content-type', 'text/plain')]
|
||||
start_response(status, response_headers, ei)
|
||||
assert "Internal Server Error" in self._serve(app)
|
||||
yield b""
|
||||
assert b"Internal Server Error" in self._serve(app)
|
||||
|
||||
def test_serve_double_err(self):
|
||||
def app(environ, start_response):
|
||||
@ -99,7 +100,7 @@ class TestWSGI:
|
||||
status = '200 OK'
|
||||
response_headers = [('Content-type', 'text/plain')]
|
||||
start_response(status, response_headers)
|
||||
yield "aaa"
|
||||
yield b"aaa"
|
||||
start_response(status, response_headers, ei)
|
||||
yield "bbb"
|
||||
assert "Internal Server Error" in self._serve(app)
|
||||
yield b"bbb"
|
||||
assert b"Internal Server Error" in self._serve(app)
|
||||
|
@ -41,7 +41,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
key = self.protocol.check_client_handshake(req.headers)
|
||||
|
||||
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
self.wfile.write(preamble.encode() + b"\r\n")
|
||||
headers = self.protocol.server_handshake_headers(key)
|
||||
self.wfile.write(str(headers) + "\r\n")
|
||||
self.wfile.flush()
|
||||
@ -62,11 +62,11 @@ class WebSocketsClient(tcp.TCPClient):
|
||||
def connect(self):
|
||||
super(WebSocketsClient, self).connect()
|
||||
|
||||
preamble = 'GET / HTTP/1.1'
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
preamble = b'GET / HTTP/1.1'
|
||||
self.wfile.write(preamble + b"\r\n")
|
||||
headers = self.protocol.client_handshake_headers()
|
||||
self.client_nonce = headers["sec-websocket-key"]
|
||||
self.wfile.write(str(headers) + "\r\n")
|
||||
self.wfile.write(bytes(headers) + b"\r\n")
|
||||
self.wfile.flush()
|
||||
|
||||
resp = read_response(self.rfile, treq(method="GET"))
|
||||
@ -101,7 +101,7 @@ class TestWebSockets(tservers.ServerTestBase):
|
||||
assert response == msg
|
||||
|
||||
def test_simple_echo(self):
|
||||
self.echo("hello I'm the client")
|
||||
self.echo(b"hello I'm the client")
|
||||
|
||||
def test_frame_sizes(self):
|
||||
# length can fit in the the 7 bit payload length
|
||||
@ -161,10 +161,10 @@ class BadHandshakeHandler(WebSocketsEchoHandler):
|
||||
client_hs = read_request(self.rfile)
|
||||
self.protocol.check_client_handshake(client_hs.headers)
|
||||
|
||||
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
headers = self.protocol.server_handshake_headers("malformed key")
|
||||
self.wfile.write(str(headers) + "\r\n")
|
||||
preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
|
||||
self.wfile.write(preamble.encode())
|
||||
headers = self.protocol.server_handshake_headers(b"malformed key")
|
||||
self.wfile.write(bytes(headers) + b"\r\n")
|
||||
self.wfile.flush()
|
||||
self.handshake_done = True
|
||||
|
||||
@ -180,7 +180,7 @@ class TestBadHandshake(tservers.ServerTestBase):
|
||||
def test(self):
|
||||
client = WebSocketsClient(("127.0.0.1", self.port))
|
||||
client.connect()
|
||||
client.send_message("hello")
|
||||
client.send_message(b"hello")
|
||||
|
||||
|
||||
class TestFrameHeader:
|
||||
@ -188,8 +188,7 @@ class TestFrameHeader:
|
||||
def test_roundtrip(self):
|
||||
def round(*args, **kwargs):
|
||||
f = websockets.FrameHeader(*args, **kwargs)
|
||||
bytes = f.to_bytes()
|
||||
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
|
||||
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f)))
|
||||
assert f == f2
|
||||
round()
|
||||
round(fin=1)
|
||||
@ -201,11 +200,11 @@ class TestFrameHeader:
|
||||
round(payload_length=1000)
|
||||
round(payload_length=10000)
|
||||
round(opcode=websockets.OPCODE.PING)
|
||||
round(masking_key="test")
|
||||
round(masking_key=b"test")
|
||||
|
||||
def test_human_readable(self):
|
||||
f = websockets.FrameHeader(
|
||||
masking_key="test",
|
||||
masking_key=b"test",
|
||||
fin=True,
|
||||
payload_length=10
|
||||
)
|
||||
@ -214,23 +213,23 @@ class TestFrameHeader:
|
||||
assert f.human_readable()
|
||||
|
||||
def test_funky(self):
|
||||
f = websockets.FrameHeader(masking_key="test", mask=False)
|
||||
f = websockets.FrameHeader(masking_key=b"test", mask=False)
|
||||
bytes = f.to_bytes()
|
||||
f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
|
||||
assert not f2.mask
|
||||
|
||||
def test_violations(self):
|
||||
tutils.raises("opcode", websockets.FrameHeader, opcode=17)
|
||||
tutils.raises("masking key", websockets.FrameHeader, masking_key="x")
|
||||
tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
|
||||
|
||||
def test_automask(self):
|
||||
f = websockets.FrameHeader(mask=True)
|
||||
assert f.masking_key
|
||||
|
||||
f = websockets.FrameHeader(masking_key="foob")
|
||||
f = websockets.FrameHeader(masking_key=b"foob")
|
||||
assert f.mask
|
||||
|
||||
f = websockets.FrameHeader(masking_key="foob", mask=0)
|
||||
f = websockets.FrameHeader(masking_key=b"foob", mask=0)
|
||||
assert not f.mask
|
||||
assert f.masking_key
|
||||
|
||||
@ -240,31 +239,31 @@ class TestFrame:
|
||||
def test_roundtrip(self):
|
||||
def round(*args, **kwargs):
|
||||
f = websockets.Frame(*args, **kwargs)
|
||||
bytes = f.to_bytes()
|
||||
f2 = websockets.Frame.from_file(tutils.treader(bytes))
|
||||
raw = bytes(f)
|
||||
f2 = websockets.Frame.from_file(tutils.treader(raw))
|
||||
assert f == f2
|
||||
round("test")
|
||||
round("test", fin=1)
|
||||
round("test", rsv1=1)
|
||||
round("test", opcode=websockets.OPCODE.PING)
|
||||
round("test", masking_key="test")
|
||||
round(b"test")
|
||||
round(b"test", fin=1)
|
||||
round(b"test", rsv1=1)
|
||||
round(b"test", opcode=websockets.OPCODE.PING)
|
||||
round(b"test", masking_key=b"test")
|
||||
|
||||
def test_human_readable(self):
|
||||
f = websockets.Frame()
|
||||
assert f.human_readable()
|
||||
assert repr(f)
|
||||
|
||||
|
||||
def test_masker():
|
||||
tests = [
|
||||
["a"],
|
||||
["four"],
|
||||
["fourf"],
|
||||
["fourfive"],
|
||||
["a", "aasdfasdfa", "asdf"],
|
||||
["a" * 50, "aasdfasdfa", "asdf"],
|
||||
[b"a"],
|
||||
[b"four"],
|
||||
[b"fourf"],
|
||||
[b"fourfive"],
|
||||
[b"a", b"aasdfasdfa", b"asdf"],
|
||||
[b"a" * 50, b"aasdfasdfa", b"asdf"],
|
||||
]
|
||||
for i in tests:
|
||||
m = websockets.Masker("abcd")
|
||||
data = "".join([m(t) for t in i])
|
||||
data2 = websockets.Masker("abcd")(data)
|
||||
assert data2 == "".join(i)
|
||||
m = websockets.Masker(b"abcd")
|
||||
data = b"".join([m(t) for t in i])
|
||||
data2 = websockets.Masker(b"abcd")(data)
|
||||
assert data2 == b"".join(i)
|
||||
|
Loading…
Reference in New Issue
Block a user