mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 23:09:44 +00:00
sni is now str, not bytes
This commit is contained in:
parent
8287ce7e6d
commit
64a867973d
@ -8,7 +8,6 @@ import six
|
||||
|
||||
from mitmproxy import stateobject
|
||||
from netlib import certutils
|
||||
from netlib import strutils
|
||||
from netlib import tcp
|
||||
|
||||
|
||||
@ -162,7 +161,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||
source_address=tcp.Address,
|
||||
ssl_established=bool,
|
||||
cert=certutils.SSLCert,
|
||||
sni=bytes,
|
||||
sni=str,
|
||||
timestamp_start=float,
|
||||
timestamp_tcp_setup=float,
|
||||
timestamp_ssl_setup=float,
|
||||
@ -206,6 +205,8 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||
self.wfile.flush()
|
||||
|
||||
def establish_ssl(self, clientcerts, sni, **kwargs):
|
||||
if sni and not isinstance(sni, six.string_types):
|
||||
raise ValueError("sni must be str, not " + type(sni).__name__)
|
||||
clientcert = None
|
||||
if clientcerts:
|
||||
if os.path.isfile(clientcerts):
|
||||
@ -217,7 +218,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||
if os.path.exists(path):
|
||||
clientcert = path
|
||||
|
||||
self.convert_to_ssl(cert=clientcert, sni=strutils.always_bytes(sni), **kwargs)
|
||||
self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs)
|
||||
self.sni = sni
|
||||
self.timestamp_ssl_setup = time.time()
|
||||
|
||||
|
@ -9,6 +9,7 @@ from mitmproxy.models.connections import ClientConnection
|
||||
from mitmproxy.models.connections import ServerConnection
|
||||
|
||||
from netlib import version
|
||||
from typing import Optional # noqa
|
||||
|
||||
|
||||
class Error(stateobject.StateObject):
|
||||
@ -70,18 +71,13 @@ class Flow(stateobject.StateObject):
|
||||
def __init__(self, type, client_conn, server_conn, live=None):
|
||||
self.type = type
|
||||
self.id = str(uuid.uuid4())
|
||||
self.client_conn = client_conn
|
||||
"""@type: ClientConnection"""
|
||||
self.server_conn = server_conn
|
||||
"""@type: ServerConnection"""
|
||||
self.client_conn = client_conn # type: ClientConnection
|
||||
self.server_conn = server_conn # type: ServerConnection
|
||||
self.live = live
|
||||
"""@type: LiveConnection"""
|
||||
|
||||
self.error = None
|
||||
"""@type: Error"""
|
||||
self.intercepted = False
|
||||
"""@type: bool"""
|
||||
self._backup = None
|
||||
self.error = None # type: Error
|
||||
self.intercepted = False # type: bool
|
||||
self._backup = None # type: Optional[Flow]
|
||||
self.reply = None
|
||||
|
||||
_stateobject_attributes = dict(
|
||||
|
@ -10,6 +10,7 @@ import netlib.exceptions
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy.contrib.tls import _constructs
|
||||
from mitmproxy.protocol import base
|
||||
from netlib import utils
|
||||
|
||||
|
||||
# taken from https://testssl.sh/openssl-rfc.mappping.html
|
||||
@ -274,10 +275,11 @@ class TlsClientHello(object):
|
||||
is_valid_sni_extension = (
|
||||
extension.type == 0x00 and
|
||||
len(extension.server_names) == 1 and
|
||||
extension.server_names[0].type == 0
|
||||
extension.server_names[0].type == 0 and
|
||||
utils.is_valid_host(extension.server_names[0].name)
|
||||
)
|
||||
if is_valid_sni_extension:
|
||||
return extension.server_names[0].name
|
||||
return extension.server_names[0].name.decode("idna")
|
||||
|
||||
@property
|
||||
def alpn_protocols(self):
|
||||
@ -403,13 +405,14 @@ class TlsLayer(base.Layer):
|
||||
self._establish_tls_with_server()
|
||||
|
||||
def set_server_tls(self, server_tls, sni=None):
|
||||
# type: (bool, Union[six.text_type, None, False]) -> None
|
||||
"""
|
||||
Set the TLS settings for the next server connection that will be established.
|
||||
This function will not alter an existing connection.
|
||||
|
||||
Args:
|
||||
server_tls: Shall we establish TLS with the server?
|
||||
sni: ``bytes`` for a custom SNI value,
|
||||
sni: ``str`` for a custom SNI value,
|
||||
``None`` for the client SNI value,
|
||||
``False`` if no SNI value should be sent.
|
||||
"""
|
||||
@ -602,9 +605,9 @@ class TlsLayer(base.Layer):
|
||||
host = upstream_cert.cn.decode("utf8").encode("idna")
|
||||
# Also add SNI values.
|
||||
if self._client_hello.sni:
|
||||
sans.add(self._client_hello.sni)
|
||||
sans.add(self._client_hello.sni.encode("idna"))
|
||||
if self._custom_server_sni:
|
||||
sans.add(self._custom_server_sni)
|
||||
sans.add(self._custom_server_sni.encode("idna"))
|
||||
|
||||
# RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity.
|
||||
# In other words, the Common Name is irrelevant then.
|
||||
|
@ -676,7 +676,7 @@ class TCPClient(_Connection):
|
||||
self.connection = SSL.Connection(context, self.connection)
|
||||
if sni:
|
||||
self.sni = sni
|
||||
self.connection.set_tlsext_host_name(sni)
|
||||
self.connection.set_tlsext_host_name(sni.encode("idna"))
|
||||
self.connection.set_connect_state()
|
||||
try:
|
||||
self.connection.do_handshake()
|
||||
@ -705,7 +705,7 @@ class TCPClient(_Connection):
|
||||
if self.cert.cn:
|
||||
crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]]
|
||||
if sni:
|
||||
hostname = sni.decode("ascii", "strict")
|
||||
hostname = sni
|
||||
else:
|
||||
hostname = "no-hostname"
|
||||
ssl_match_hostname.match_hostname(crt, hostname)
|
||||
|
@ -73,11 +73,9 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
|
||||
|
||||
|
||||
def is_valid_host(host):
|
||||
# type: (bytes) -> bool
|
||||
"""
|
||||
Checks if a hostname is valid.
|
||||
|
||||
Args:
|
||||
host (bytes): The hostname
|
||||
"""
|
||||
try:
|
||||
host.decode("idna")
|
||||
|
@ -89,7 +89,10 @@ class PathodHandler(tcp.BaseHandler):
|
||||
self.http2_framedump = http2_framedump
|
||||
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
sni = connection.get_servername()
|
||||
if sni:
|
||||
sni = sni.decode("idna")
|
||||
self.sni = sni
|
||||
|
||||
def http_serve_crafted(self, crafted, logctx):
|
||||
error, crafted = self.server.check_policy(
|
||||
|
@ -100,10 +100,10 @@ class CommonMixin:
|
||||
if not self.ssl:
|
||||
return
|
||||
|
||||
f = self.pathod("304", sni=b"testserver.com")
|
||||
f = self.pathod("304", sni="testserver.com")
|
||||
assert f.status_code == 304
|
||||
log = self.server.last_log()
|
||||
assert log["request"]["sni"] == b"testserver.com"
|
||||
assert log["request"]["sni"] == "testserver.com"
|
||||
|
||||
|
||||
class TcpMixin:
|
||||
@ -498,7 +498,7 @@ class TestHttps2Http(tservers.ReverseProxyTest):
|
||||
assert p.request("get:'/p/200'").status_code == 200
|
||||
|
||||
def test_sni(self):
|
||||
p = self.pathoc(ssl=True, sni=b"example.com")
|
||||
p = self.pathoc(ssl=True, sni="example.com")
|
||||
assert p.request("get:'/p/200'").status_code == 200
|
||||
assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog)
|
||||
|
||||
|
@ -130,7 +130,7 @@ def tserver_conn():
|
||||
timestamp_ssl_setup=3,
|
||||
timestamp_end=4,
|
||||
ssl_established=False,
|
||||
sni=b"address",
|
||||
sni="address",
|
||||
via=None
|
||||
))
|
||||
c.reply = controller.DummyReply()
|
||||
|
@ -169,7 +169,7 @@ class TestServerSSL(tservers.ServerTestBase):
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL)
|
||||
c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL)
|
||||
testval = b"echo!\n"
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
@ -179,7 +179,7 @@ class TestServerSSL(tservers.ServerTestBase):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
assert not c.get_current_cipher()
|
||||
c.convert_to_ssl(sni=b"foo.com")
|
||||
c.convert_to_ssl(sni="foo.com")
|
||||
ret = c.get_current_cipher()
|
||||
assert ret
|
||||
assert "AES" in ret[0]
|
||||
@ -195,7 +195,7 @@ class TestSSLv3Only(tservers.ServerTestBase):
|
||||
def test_failure(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com")
|
||||
tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com")
|
||||
|
||||
|
||||
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
|
||||
@ -238,7 +238,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
|
||||
with c.connect():
|
||||
with tutils.raises(InvalidCertificateException):
|
||||
c.convert_to_ssl(
|
||||
sni=b"example.mitmproxy.org",
|
||||
sni="example.mitmproxy.org",
|
||||
verify_options=SSL.VERIFY_PEER,
|
||||
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
|
||||
)
|
||||
@ -272,7 +272,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
|
||||
with c.connect():
|
||||
with tutils.raises(InvalidCertificateException):
|
||||
c.convert_to_ssl(
|
||||
sni=b"mitmproxy.org",
|
||||
sni="mitmproxy.org",
|
||||
verify_options=SSL.VERIFY_PEER,
|
||||
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
|
||||
)
|
||||
@ -291,7 +291,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_ssl(
|
||||
sni=b"example.mitmproxy.org",
|
||||
sni="example.mitmproxy.org",
|
||||
verify_options=SSL.VERIFY_PEER,
|
||||
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
|
||||
)
|
||||
@ -307,7 +307,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_ssl(
|
||||
sni=b"example.mitmproxy.org",
|
||||
sni="example.mitmproxy.org",
|
||||
verify_options=SSL.VERIFY_PEER,
|
||||
ca_path=tutils.test_data.path("data/verificationcerts/")
|
||||
)
|
||||
@ -371,8 +371,8 @@ class TestSNI(tservers.ServerTestBase):
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_ssl(sni=b"foo.com")
|
||||
assert c.sni == b"foo.com"
|
||||
c.convert_to_ssl(sni="foo.com")
|
||||
assert c.sni == "foo.com"
|
||||
assert c.rfile.readline() == b"foo.com"
|
||||
|
||||
|
||||
@ -385,7 +385,7 @@ class TestServerCipherList(tservers.ServerTestBase):
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_ssl(sni=b"foo.com")
|
||||
c.convert_to_ssl(sni="foo.com")
|
||||
assert c.rfile.readline() == b"['RC4-SHA']"
|
||||
|
||||
|
||||
@ -405,7 +405,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase):
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_ssl(sni=b"foo.com")
|
||||
c.convert_to_ssl(sni="foo.com")
|
||||
assert b"RC4-SHA" in c.rfile.readline()
|
||||
|
||||
|
||||
@ -418,7 +418,7 @@ class TestServerCipherListError(tservers.ServerTestBase):
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com")
|
||||
tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com")
|
||||
|
||||
|
||||
class TestClientCipherListError(tservers.ServerTestBase):
|
||||
@ -433,7 +433,7 @@ class TestClientCipherListError(tservers.ServerTestBase):
|
||||
tutils.raises(
|
||||
"cipher specification",
|
||||
c.convert_to_ssl,
|
||||
sni=b"foo.com",
|
||||
sni="foo.com",
|
||||
cipher_list="bogus"
|
||||
)
|
||||
|
||||
|
@ -54,10 +54,10 @@ class TestDaemonSSL(PathocTestDaemon):
|
||||
def test_sni(self):
|
||||
self.tval(
|
||||
["get:/p/200"],
|
||||
sni=b"foobar.com"
|
||||
sni="foobar.com"
|
||||
)
|
||||
log = self.d.log()
|
||||
assert log[0]["request"]["sni"] == b"foobar.com"
|
||||
assert log[0]["request"]["sni"] == "foobar.com"
|
||||
|
||||
def test_showssl(self):
|
||||
assert "certificate chain" in self.tval(["get:/p/200"], showssl=True)
|
||||
|
Loading…
Reference in New Issue
Block a user