mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
minor encoding fixes
- native() -> always_str()
  The old function name does not make sense in a Python 3-only code base.
- Inline utility functions in message.py.
This commit is contained in:
parent
af194918cf
commit
042261266f
@ -93,7 +93,7 @@ def convert_100_200(data):
|
|||||||
|
|
||||||
def _convert_dict_keys(o: Any) -> Any:
|
def _convert_dict_keys(o: Any) -> Any:
|
||||||
if isinstance(o, dict):
|
if isinstance(o, dict):
|
||||||
return {strutils.native(k): _convert_dict_keys(v) for k, v in o.items()}
|
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
|
||||||
else:
|
else:
|
||||||
return o
|
return o
|
||||||
|
|
||||||
@ -103,7 +103,7 @@ def _convert_dict_vals(o: dict, values_to_convert: dict) -> dict:
|
|||||||
if not o or k not in o:
|
if not o or k not in o:
|
||||||
continue
|
continue
|
||||||
if v is True:
|
if v is True:
|
||||||
o[k] = strutils.native(o[k])
|
o[k] = strutils.always_str(o[k])
|
||||||
else:
|
else:
|
||||||
_convert_dict_vals(o[k], v)
|
_convert_dict_vals(o[k], v)
|
||||||
return o
|
return o
|
||||||
|
@ -7,15 +7,6 @@ from mitmproxy.types import serializable
|
|||||||
from mitmproxy.net.http import headers
|
from mitmproxy.net.http import headers
|
||||||
|
|
||||||
|
|
||||||
# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
|
|
||||||
def _native(x):
|
|
||||||
return x.decode("utf-8", "surrogateescape")
|
|
||||||
|
|
||||||
|
|
||||||
def _always_bytes(x):
|
|
||||||
return strutils.always_bytes(x, "utf-8", "surrogateescape")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageData(serializable.Serializable):
|
class MessageData(serializable.Serializable):
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, MessageData):
|
if isinstance(other, MessageData):
|
||||||
@ -142,11 +133,11 @@ class Message(serializable.Serializable):
|
|||||||
"""
|
"""
|
||||||
Version string, e.g. "HTTP/1.1"
|
Version string, e.g. "HTTP/1.1"
|
||||||
"""
|
"""
|
||||||
return _native(self.data.http_version)
|
return self.data.http_version.decode("utf-8", "surrogateescape")
|
||||||
|
|
||||||
@http_version.setter
|
@http_version.setter
|
||||||
def http_version(self, http_version):
|
def http_version(self, http_version):
|
||||||
self.data.http_version = _always_bytes(http_version)
|
self.data.http_version = strutils.always_bytes(http_version, "utf-8", "surrogateescape")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def timestamp_start(self):
|
def timestamp_start(self):
|
||||||
|
@ -115,24 +115,24 @@ class Request(message.Message):
|
|||||||
"""
|
"""
|
||||||
HTTP request method, e.g. "GET".
|
HTTP request method, e.g. "GET".
|
||||||
"""
|
"""
|
||||||
return message._native(self.data.method).upper()
|
return self.data.method.decode("utf-8", "surrogateescape").upper()
|
||||||
|
|
||||||
@method.setter
|
@method.setter
|
||||||
def method(self, method):
|
def method(self, method):
|
||||||
self.data.method = message._always_bytes(method)
|
self.data.method = strutils.always_bytes(method, "utf-8", "surrogateescape")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scheme(self):
|
def scheme(self):
|
||||||
"""
|
"""
|
||||||
HTTP request scheme, which should be "http" or "https".
|
HTTP request scheme, which should be "http" or "https".
|
||||||
"""
|
"""
|
||||||
if not self.data.scheme:
|
if self.data.scheme is None:
|
||||||
return self.data.scheme
|
return None
|
||||||
return message._native(self.data.scheme)
|
return self.data.scheme.decode("utf-8", "surrogateescape")
|
||||||
|
|
||||||
@scheme.setter
|
@scheme.setter
|
||||||
def scheme(self, scheme):
|
def scheme(self, scheme):
|
||||||
self.data.scheme = message._always_bytes(scheme)
|
self.data.scheme = strutils.always_bytes(scheme, "utf-8", "surrogateescape")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def host(self):
|
def host(self):
|
||||||
@ -190,11 +190,11 @@ class Request(message.Message):
|
|||||||
if self.data.path is None:
|
if self.data.path is None:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return message._native(self.data.path)
|
return self.data.path.decode("utf-8", "surrogateescape")
|
||||||
|
|
||||||
@path.setter
|
@path.setter
|
||||||
def path(self, path):
|
def path(self, path):
|
||||||
self.data.path = message._always_bytes(path)
|
self.data.path = strutils.always_bytes(path, "utf-8", "surrogateescape")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self):
|
def url(self):
|
||||||
|
@ -538,7 +538,7 @@ class _Connection:
|
|||||||
self.ssl_verification_error = exceptions.InvalidCertificateException(
|
self.ssl_verification_error = exceptions.InvalidCertificateException(
|
||||||
"Certificate Verification Error for {}: {} (errno: {}, depth: {})".format(
|
"Certificate Verification Error for {}: {} (errno: {}, depth: {})".format(
|
||||||
sni,
|
sni,
|
||||||
strutils.native(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"),
|
strutils.always_str(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"),
|
||||||
errno,
|
errno,
|
||||||
err_depth
|
err_depth
|
||||||
)
|
)
|
||||||
|
@ -57,38 +57,38 @@ class WSGIAdaptor:
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError, if the content-encoding is invalid.
|
ValueError, if the content-encoding is invalid.
|
||||||
"""
|
"""
|
||||||
path = strutils.native(flow.request.path, "latin-1")
|
path = strutils.always_str(flow.request.path, "latin-1")
|
||||||
if '?' in path:
|
if '?' in path:
|
||||||
path_info, query = strutils.native(path, "latin-1").split('?', 1)
|
path_info, query = strutils.always_str(path, "latin-1").split('?', 1)
|
||||||
else:
|
else:
|
||||||
path_info = path
|
path_info = path
|
||||||
query = ''
|
query = ''
|
||||||
environ = {
|
environ = {
|
||||||
'wsgi.version': (1, 0),
|
'wsgi.version': (1, 0),
|
||||||
'wsgi.url_scheme': strutils.native(flow.request.scheme, "latin-1"),
|
'wsgi.url_scheme': strutils.always_str(flow.request.scheme, "latin-1"),
|
||||||
'wsgi.input': io.BytesIO(flow.request.content or b""),
|
'wsgi.input': io.BytesIO(flow.request.content or b""),
|
||||||
'wsgi.errors': errsoc,
|
'wsgi.errors': errsoc,
|
||||||
'wsgi.multithread': True,
|
'wsgi.multithread': True,
|
||||||
'wsgi.multiprocess': False,
|
'wsgi.multiprocess': False,
|
||||||
'wsgi.run_once': False,
|
'wsgi.run_once': False,
|
||||||
'SERVER_SOFTWARE': self.sversion,
|
'SERVER_SOFTWARE': self.sversion,
|
||||||
'REQUEST_METHOD': strutils.native(flow.request.method, "latin-1"),
|
'REQUEST_METHOD': strutils.always_str(flow.request.method, "latin-1"),
|
||||||
'SCRIPT_NAME': '',
|
'SCRIPT_NAME': '',
|
||||||
'PATH_INFO': urllib.parse.unquote(path_info),
|
'PATH_INFO': urllib.parse.unquote(path_info),
|
||||||
'QUERY_STRING': query,
|
'QUERY_STRING': query,
|
||||||
'CONTENT_TYPE': strutils.native(flow.request.headers.get('Content-Type', ''), "latin-1"),
|
'CONTENT_TYPE': strutils.always_str(flow.request.headers.get('Content-Type', ''), "latin-1"),
|
||||||
'CONTENT_LENGTH': strutils.native(flow.request.headers.get('Content-Length', ''), "latin-1"),
|
'CONTENT_LENGTH': strutils.always_str(flow.request.headers.get('Content-Length', ''), "latin-1"),
|
||||||
'SERVER_NAME': self.domain,
|
'SERVER_NAME': self.domain,
|
||||||
'SERVER_PORT': str(self.port),
|
'SERVER_PORT': str(self.port),
|
||||||
'SERVER_PROTOCOL': strutils.native(flow.request.http_version, "latin-1"),
|
'SERVER_PROTOCOL': strutils.always_str(flow.request.http_version, "latin-1"),
|
||||||
}
|
}
|
||||||
environ.update(extra)
|
environ.update(extra)
|
||||||
if flow.client_conn.address:
|
if flow.client_conn.address:
|
||||||
environ["REMOTE_ADDR"] = strutils.native(flow.client_conn.address.host, "latin-1")
|
environ["REMOTE_ADDR"] = strutils.always_str(flow.client_conn.address.host, "latin-1")
|
||||||
environ["REMOTE_PORT"] = flow.client_conn.address.port
|
environ["REMOTE_PORT"] = flow.client_conn.address.port
|
||||||
|
|
||||||
for key, value in flow.request.headers.items():
|
for key, value in flow.request.headers.items():
|
||||||
key = 'HTTP_' + strutils.native(key, "latin-1").upper().replace('-', '_')
|
key = 'HTTP_' + strutils.always_str(key, "latin-1").upper().replace('-', '_')
|
||||||
if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
|
if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
|
||||||
environ[key] = value
|
environ[key] = value
|
||||||
return environ
|
return environ
|
||||||
|
@ -1,28 +1,28 @@
|
|||||||
import re
|
import re
|
||||||
import codecs
|
import codecs
|
||||||
|
from typing import AnyStr, Optional
|
||||||
|
|
||||||
|
|
||||||
def always_bytes(unicode_or_bytes, *encode_args):
|
def always_bytes(str_or_bytes: Optional[AnyStr], *encode_args) -> Optional[bytes]:
|
||||||
if isinstance(unicode_or_bytes, str):
|
if isinstance(str_or_bytes, bytes) or str_or_bytes is None:
|
||||||
return unicode_or_bytes.encode(*encode_args)
|
return str_or_bytes
|
||||||
elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None:
|
elif isinstance(str_or_bytes, str):
|
||||||
return unicode_or_bytes
|
return str_or_bytes.encode(*encode_args)
|
||||||
else:
|
else:
|
||||||
raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__))
|
raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))
|
||||||
|
|
||||||
|
|
||||||
def native(s, *encoding_opts):
|
def always_str(str_or_bytes: Optional[AnyStr], *decode_args) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Convert :py:class:`bytes` or :py:class:`unicode` to the native
|
Returns,
|
||||||
:py:class:`str` type, using latin1 encoding if conversion is necessary.
|
str_or_bytes unmodified, if
|
||||||
|
|
||||||
https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(s, (bytes, str)):
|
if isinstance(str_or_bytes, str) or str_or_bytes is None:
|
||||||
raise TypeError("%r is neither bytes nor unicode" % s)
|
return str_or_bytes
|
||||||
if isinstance(s, bytes):
|
elif isinstance(str_or_bytes, bytes):
|
||||||
return s.decode(*encoding_opts)
|
return str_or_bytes.decode(*decode_args)
|
||||||
return s
|
else:
|
||||||
|
raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))
|
||||||
|
|
||||||
|
|
||||||
# Translate control characters to "safe" characters. This implementation initially
|
# Translate control characters to "safe" characters. This implementation initially
|
||||||
@ -135,7 +135,7 @@ def hexdump(s):
|
|||||||
part = s[i:i + 16]
|
part = s[i:i + 16]
|
||||||
x = " ".join("{:0=2x}".format(i) for i in part)
|
x = " ".join("{:0=2x}".format(i) for i in part)
|
||||||
x = x.ljust(47) # 16*2 + 15
|
x = x.ljust(47) # 16*2 + 15
|
||||||
part_repr = native(escape_control_characters(
|
part_repr = always_str(escape_control_characters(
|
||||||
part.decode("ascii", "replace").replace(u"\ufffd", u"."),
|
part.decode("ascii", "replace").replace(u"\ufffd", u"."),
|
||||||
False
|
False
|
||||||
))
|
))
|
||||||
|
@ -61,7 +61,7 @@ class LogCtx:
|
|||||||
for line in strutils.hexdump(data):
|
for line in strutils.hexdump(data):
|
||||||
self("\t%s %s %s" % line)
|
self("\t%s %s %s" % line)
|
||||||
else:
|
else:
|
||||||
data = strutils.native(
|
data = strutils.always_str(
|
||||||
strutils.escape_control_characters(
|
strutils.escape_control_characters(
|
||||||
data
|
data
|
||||||
.decode("ascii", "replace")
|
.decode("ascii", "replace")
|
||||||
|
@ -44,7 +44,7 @@ class SSLInfo:
|
|||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
parts = [
|
parts = [
|
||||||
"Application Layer Protocol: %s" % strutils.native(self.alp, "utf8"),
|
"Application Layer Protocol: %s" % strutils.always_str(self.alp, "utf8"),
|
||||||
"Cipher: %s, %s bit, %s" % self.cipher,
|
"Cipher: %s, %s bit, %s" % self.cipher,
|
||||||
"SSL certificate chain:"
|
"SSL certificate chain:"
|
||||||
]
|
]
|
||||||
@ -53,24 +53,24 @@ class SSLInfo:
|
|||||||
parts.append("\tSubject: ")
|
parts.append("\tSubject: ")
|
||||||
for cn in i.get_subject().get_components():
|
for cn in i.get_subject().get_components():
|
||||||
parts.append("\t\t%s=%s" % (
|
parts.append("\t\t%s=%s" % (
|
||||||
strutils.native(cn[0], "utf8"),
|
strutils.always_str(cn[0], "utf8"),
|
||||||
strutils.native(cn[1], "utf8"))
|
strutils.always_str(cn[1], "utf8"))
|
||||||
)
|
)
|
||||||
parts.append("\tIssuer: ")
|
parts.append("\tIssuer: ")
|
||||||
for cn in i.get_issuer().get_components():
|
for cn in i.get_issuer().get_components():
|
||||||
parts.append("\t\t%s=%s" % (
|
parts.append("\t\t%s=%s" % (
|
||||||
strutils.native(cn[0], "utf8"),
|
strutils.always_str(cn[0], "utf8"),
|
||||||
strutils.native(cn[1], "utf8"))
|
strutils.always_str(cn[1], "utf8"))
|
||||||
)
|
)
|
||||||
parts.extend(
|
parts.extend(
|
||||||
[
|
[
|
||||||
"\tVersion: %s" % i.get_version(),
|
"\tVersion: %s" % i.get_version(),
|
||||||
"\tValidity: %s - %s" % (
|
"\tValidity: %s - %s" % (
|
||||||
strutils.native(i.get_notBefore(), "utf8"),
|
strutils.always_str(i.get_notBefore(), "utf8"),
|
||||||
strutils.native(i.get_notAfter(), "utf8")
|
strutils.always_str(i.get_notAfter(), "utf8")
|
||||||
),
|
),
|
||||||
"\tSerial: %s" % i.get_serial_number(),
|
"\tSerial: %s" % i.get_serial_number(),
|
||||||
"\tAlgorithm: %s" % strutils.native(i.get_signature_algorithm(), "utf8")
|
"\tAlgorithm: %s" % strutils.always_str(i.get_signature_algorithm(), "utf8")
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
pk = i.get_pubkey()
|
pk = i.get_pubkey()
|
||||||
@ -82,7 +82,7 @@ class SSLInfo:
|
|||||||
parts.append("\tPubkey: %s bit %s" % (pk.bits(), t))
|
parts.append("\tPubkey: %s bit %s" % (pk.bits(), t))
|
||||||
s = certs.SSLCert(i)
|
s = certs.SSLCert(i)
|
||||||
if s.altnames:
|
if s.altnames:
|
||||||
parts.append("\tSANs: %s" % " ".join(strutils.native(n, "utf8") for n in s.altnames))
|
parts.append("\tSANs: %s" % " ".join(strutils.always_str(n, "utf8") for n in s.altnames))
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,7 +55,20 @@ class TestResponseCore:
|
|||||||
_test_passthrough_attr(tresp(), "status_code")
|
_test_passthrough_attr(tresp(), "status_code")
|
||||||
|
|
||||||
def test_reason(self):
|
def test_reason(self):
|
||||||
_test_decoded_attr(tresp(), "reason")
|
resp = tresp()
|
||||||
|
assert resp.reason == "OK"
|
||||||
|
|
||||||
|
resp.reason = "ABC"
|
||||||
|
assert resp.data.reason == b"ABC"
|
||||||
|
|
||||||
|
resp.reason = b"DEF"
|
||||||
|
assert resp.data.reason == b"DEF"
|
||||||
|
|
||||||
|
resp.reason = None
|
||||||
|
assert resp.data.reason is None
|
||||||
|
|
||||||
|
resp.data.reason = b'cr\xe9e'
|
||||||
|
assert resp.reason == "crée"
|
||||||
|
|
||||||
|
|
||||||
class TestResponseUtils:
|
class TestResponseUtils:
|
||||||
|
@ -11,11 +11,12 @@ def test_always_bytes():
|
|||||||
strutils.always_bytes(42, "ascii")
|
strutils.always_bytes(42, "ascii")
|
||||||
|
|
||||||
|
|
||||||
def test_native():
|
def test_always_str():
|
||||||
with tutils.raises(TypeError):
|
with tutils.raises(TypeError):
|
||||||
strutils.native(42)
|
strutils.always_str(42)
|
||||||
assert strutils.native(u"foo") == u"foo"
|
assert strutils.always_str("foo") == "foo"
|
||||||
assert strutils.native(b"foo") == u"foo"
|
assert strutils.always_str(b"foo") == "foo"
|
||||||
|
assert strutils.always_str(None) is None
|
||||||
|
|
||||||
|
|
||||||
def test_escape_control_characters():
|
def test_escape_control_characters():
|
||||||
|
Loading…
Reference in New Issue
Block a user