minor encoding fixes

- native() -> always_str()
  The old function name does not make sense on Python 3 only.
- Inline utility functions in message.py.
This commit is contained in:
Maximilian Hils 2017-01-06 00:31:06 +01:00
parent af194918cf
commit 042261266f
10 changed files with 68 additions and 63 deletions

View File

@ -93,7 +93,7 @@ def convert_100_200(data):
def _convert_dict_keys(o: Any) -> Any: def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict): if isinstance(o, dict):
return {strutils.native(k): _convert_dict_keys(v) for k, v in o.items()} return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
else: else:
return o return o
@ -103,7 +103,7 @@ def _convert_dict_vals(o: dict, values_to_convert: dict) -> dict:
if not o or k not in o: if not o or k not in o:
continue continue
if v is True: if v is True:
o[k] = strutils.native(o[k]) o[k] = strutils.always_str(o[k])
else: else:
_convert_dict_vals(o[k], v) _convert_dict_vals(o[k], v)
return o return o

View File

@ -7,15 +7,6 @@ from mitmproxy.types import serializable
from mitmproxy.net.http import headers from mitmproxy.net.http import headers
# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
def _native(x):
return x.decode("utf-8", "surrogateescape")
def _always_bytes(x):
return strutils.always_bytes(x, "utf-8", "surrogateescape")
class MessageData(serializable.Serializable): class MessageData(serializable.Serializable):
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, MessageData): if isinstance(other, MessageData):
@ -142,11 +133,11 @@ class Message(serializable.Serializable):
""" """
Version string, e.g. "HTTP/1.1" Version string, e.g. "HTTP/1.1"
""" """
return _native(self.data.http_version) return self.data.http_version.decode("utf-8", "surrogateescape")
@http_version.setter @http_version.setter
def http_version(self, http_version): def http_version(self, http_version):
self.data.http_version = _always_bytes(http_version) self.data.http_version = strutils.always_bytes(http_version, "utf-8", "surrogateescape")
@property @property
def timestamp_start(self): def timestamp_start(self):

View File

@ -115,24 +115,24 @@ class Request(message.Message):
""" """
HTTP request method, e.g. "GET". HTTP request method, e.g. "GET".
""" """
return message._native(self.data.method).upper() return self.data.method.decode("utf-8", "surrogateescape").upper()
@method.setter @method.setter
def method(self, method): def method(self, method):
self.data.method = message._always_bytes(method) self.data.method = strutils.always_bytes(method, "utf-8", "surrogateescape")
@property @property
def scheme(self): def scheme(self):
""" """
HTTP request scheme, which should be "http" or "https". HTTP request scheme, which should be "http" or "https".
""" """
if not self.data.scheme: if self.data.scheme is None:
return self.data.scheme return None
return message._native(self.data.scheme) return self.data.scheme.decode("utf-8", "surrogateescape")
@scheme.setter @scheme.setter
def scheme(self, scheme): def scheme(self, scheme):
self.data.scheme = message._always_bytes(scheme) self.data.scheme = strutils.always_bytes(scheme, "utf-8", "surrogateescape")
@property @property
def host(self): def host(self):
@ -190,11 +190,11 @@ class Request(message.Message):
if self.data.path is None: if self.data.path is None:
return None return None
else: else:
return message._native(self.data.path) return self.data.path.decode("utf-8", "surrogateescape")
@path.setter @path.setter
def path(self, path): def path(self, path):
self.data.path = message._always_bytes(path) self.data.path = strutils.always_bytes(path, "utf-8", "surrogateescape")
@property @property
def url(self): def url(self):

View File

@ -538,7 +538,7 @@ class _Connection:
self.ssl_verification_error = exceptions.InvalidCertificateException( self.ssl_verification_error = exceptions.InvalidCertificateException(
"Certificate Verification Error for {}: {} (errno: {}, depth: {})".format( "Certificate Verification Error for {}: {} (errno: {}, depth: {})".format(
sni, sni,
strutils.native(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"), strutils.always_str(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"),
errno, errno,
err_depth err_depth
) )

View File

@ -57,38 +57,38 @@ class WSGIAdaptor:
Raises: Raises:
ValueError, if the content-encoding is invalid. ValueError, if the content-encoding is invalid.
""" """
path = strutils.native(flow.request.path, "latin-1") path = strutils.always_str(flow.request.path, "latin-1")
if '?' in path: if '?' in path:
path_info, query = strutils.native(path, "latin-1").split('?', 1) path_info, query = strutils.always_str(path, "latin-1").split('?', 1)
else: else:
path_info = path path_info = path
query = '' query = ''
environ = { environ = {
'wsgi.version': (1, 0), 'wsgi.version': (1, 0),
'wsgi.url_scheme': strutils.native(flow.request.scheme, "latin-1"), 'wsgi.url_scheme': strutils.always_str(flow.request.scheme, "latin-1"),
'wsgi.input': io.BytesIO(flow.request.content or b""), 'wsgi.input': io.BytesIO(flow.request.content or b""),
'wsgi.errors': errsoc, 'wsgi.errors': errsoc,
'wsgi.multithread': True, 'wsgi.multithread': True,
'wsgi.multiprocess': False, 'wsgi.multiprocess': False,
'wsgi.run_once': False, 'wsgi.run_once': False,
'SERVER_SOFTWARE': self.sversion, 'SERVER_SOFTWARE': self.sversion,
'REQUEST_METHOD': strutils.native(flow.request.method, "latin-1"), 'REQUEST_METHOD': strutils.always_str(flow.request.method, "latin-1"),
'SCRIPT_NAME': '', 'SCRIPT_NAME': '',
'PATH_INFO': urllib.parse.unquote(path_info), 'PATH_INFO': urllib.parse.unquote(path_info),
'QUERY_STRING': query, 'QUERY_STRING': query,
'CONTENT_TYPE': strutils.native(flow.request.headers.get('Content-Type', ''), "latin-1"), 'CONTENT_TYPE': strutils.always_str(flow.request.headers.get('Content-Type', ''), "latin-1"),
'CONTENT_LENGTH': strutils.native(flow.request.headers.get('Content-Length', ''), "latin-1"), 'CONTENT_LENGTH': strutils.always_str(flow.request.headers.get('Content-Length', ''), "latin-1"),
'SERVER_NAME': self.domain, 'SERVER_NAME': self.domain,
'SERVER_PORT': str(self.port), 'SERVER_PORT': str(self.port),
'SERVER_PROTOCOL': strutils.native(flow.request.http_version, "latin-1"), 'SERVER_PROTOCOL': strutils.always_str(flow.request.http_version, "latin-1"),
} }
environ.update(extra) environ.update(extra)
if flow.client_conn.address: if flow.client_conn.address:
environ["REMOTE_ADDR"] = strutils.native(flow.client_conn.address.host, "latin-1") environ["REMOTE_ADDR"] = strutils.always_str(flow.client_conn.address.host, "latin-1")
environ["REMOTE_PORT"] = flow.client_conn.address.port environ["REMOTE_PORT"] = flow.client_conn.address.port
for key, value in flow.request.headers.items(): for key, value in flow.request.headers.items():
key = 'HTTP_' + strutils.native(key, "latin-1").upper().replace('-', '_') key = 'HTTP_' + strutils.always_str(key, "latin-1").upper().replace('-', '_')
if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
environ[key] = value environ[key] = value
return environ return environ

View File

@ -1,28 +1,28 @@
import re import re
import codecs import codecs
from typing import AnyStr, Optional
def always_bytes(unicode_or_bytes, *encode_args): def always_bytes(str_or_bytes: Optional[AnyStr], *encode_args) -> Optional[bytes]:
if isinstance(unicode_or_bytes, str): if isinstance(str_or_bytes, bytes) or str_or_bytes is None:
return unicode_or_bytes.encode(*encode_args) return str_or_bytes
elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None: elif isinstance(str_or_bytes, str):
return unicode_or_bytes return str_or_bytes.encode(*encode_args)
else: else:
raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__)) raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))
def native(s, *encoding_opts): def always_str(str_or_bytes: Optional[AnyStr], *decode_args) -> Optional[str]:
""" """
Convert :py:class:`bytes` or :py:class:`unicode` to the native Returns,
:py:class:`str` type, using latin1 encoding if conversion is necessary. str_or_bytes unmodified, if
https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types
""" """
if not isinstance(s, (bytes, str)): if isinstance(str_or_bytes, str) or str_or_bytes is None:
raise TypeError("%r is neither bytes nor unicode" % s) return str_or_bytes
if isinstance(s, bytes): elif isinstance(str_or_bytes, bytes):
return s.decode(*encoding_opts) return str_or_bytes.decode(*decode_args)
return s else:
raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))
# Translate control characters to "safe" characters. This implementation initially # Translate control characters to "safe" characters. This implementation initially
@ -135,7 +135,7 @@ def hexdump(s):
part = s[i:i + 16] part = s[i:i + 16]
x = " ".join("{:0=2x}".format(i) for i in part) x = " ".join("{:0=2x}".format(i) for i in part)
x = x.ljust(47) # 16*2 + 15 x = x.ljust(47) # 16*2 + 15
part_repr = native(escape_control_characters( part_repr = always_str(escape_control_characters(
part.decode("ascii", "replace").replace(u"\ufffd", u"."), part.decode("ascii", "replace").replace(u"\ufffd", u"."),
False False
)) ))

View File

@ -61,7 +61,7 @@ class LogCtx:
for line in strutils.hexdump(data): for line in strutils.hexdump(data):
self("\t%s %s %s" % line) self("\t%s %s %s" % line)
else: else:
data = strutils.native( data = strutils.always_str(
strutils.escape_control_characters( strutils.escape_control_characters(
data data
.decode("ascii", "replace") .decode("ascii", "replace")

View File

@ -44,7 +44,7 @@ class SSLInfo:
def __str__(self): def __str__(self):
parts = [ parts = [
"Application Layer Protocol: %s" % strutils.native(self.alp, "utf8"), "Application Layer Protocol: %s" % strutils.always_str(self.alp, "utf8"),
"Cipher: %s, %s bit, %s" % self.cipher, "Cipher: %s, %s bit, %s" % self.cipher,
"SSL certificate chain:" "SSL certificate chain:"
] ]
@ -53,24 +53,24 @@ class SSLInfo:
parts.append("\tSubject: ") parts.append("\tSubject: ")
for cn in i.get_subject().get_components(): for cn in i.get_subject().get_components():
parts.append("\t\t%s=%s" % ( parts.append("\t\t%s=%s" % (
strutils.native(cn[0], "utf8"), strutils.always_str(cn[0], "utf8"),
strutils.native(cn[1], "utf8")) strutils.always_str(cn[1], "utf8"))
) )
parts.append("\tIssuer: ") parts.append("\tIssuer: ")
for cn in i.get_issuer().get_components(): for cn in i.get_issuer().get_components():
parts.append("\t\t%s=%s" % ( parts.append("\t\t%s=%s" % (
strutils.native(cn[0], "utf8"), strutils.always_str(cn[0], "utf8"),
strutils.native(cn[1], "utf8")) strutils.always_str(cn[1], "utf8"))
) )
parts.extend( parts.extend(
[ [
"\tVersion: %s" % i.get_version(), "\tVersion: %s" % i.get_version(),
"\tValidity: %s - %s" % ( "\tValidity: %s - %s" % (
strutils.native(i.get_notBefore(), "utf8"), strutils.always_str(i.get_notBefore(), "utf8"),
strutils.native(i.get_notAfter(), "utf8") strutils.always_str(i.get_notAfter(), "utf8")
), ),
"\tSerial: %s" % i.get_serial_number(), "\tSerial: %s" % i.get_serial_number(),
"\tAlgorithm: %s" % strutils.native(i.get_signature_algorithm(), "utf8") "\tAlgorithm: %s" % strutils.always_str(i.get_signature_algorithm(), "utf8")
] ]
) )
pk = i.get_pubkey() pk = i.get_pubkey()
@ -82,7 +82,7 @@ class SSLInfo:
parts.append("\tPubkey: %s bit %s" % (pk.bits(), t)) parts.append("\tPubkey: %s bit %s" % (pk.bits(), t))
s = certs.SSLCert(i) s = certs.SSLCert(i)
if s.altnames: if s.altnames:
parts.append("\tSANs: %s" % " ".join(strutils.native(n, "utf8") for n in s.altnames)) parts.append("\tSANs: %s" % " ".join(strutils.always_str(n, "utf8") for n in s.altnames))
return "\n".join(parts) return "\n".join(parts)

View File

@ -55,7 +55,20 @@ class TestResponseCore:
_test_passthrough_attr(tresp(), "status_code") _test_passthrough_attr(tresp(), "status_code")
def test_reason(self): def test_reason(self):
_test_decoded_attr(tresp(), "reason") resp = tresp()
assert resp.reason == "OK"
resp.reason = "ABC"
assert resp.data.reason == b"ABC"
resp.reason = b"DEF"
assert resp.data.reason == b"DEF"
resp.reason = None
assert resp.data.reason is None
resp.data.reason = b'cr\xe9e'
assert resp.reason == "crée"
class TestResponseUtils: class TestResponseUtils:

View File

@ -11,11 +11,12 @@ def test_always_bytes():
strutils.always_bytes(42, "ascii") strutils.always_bytes(42, "ascii")
def test_native(): def test_always_str():
with tutils.raises(TypeError): with tutils.raises(TypeError):
strutils.native(42) strutils.always_str(42)
assert strutils.native(u"foo") == u"foo" assert strutils.always_str("foo") == "foo"
assert strutils.native(b"foo") == u"foo" assert strutils.always_str(b"foo") == "foo"
assert strutils.always_str(None) is None
def test_escape_control_characters(): def test_escape_control_characters():