mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
avoid TLS/SSL ambiguity for Cert class
This commit is contained in:
parent
d15e96dee1
commit
4fb894cad4
@ -43,7 +43,7 @@ def extract(cut: str, f: flow.Flow) -> typing.Union[str, bytes]:
|
|||||||
return part
|
return part
|
||||||
elif isinstance(part, bool):
|
elif isinstance(part, bool):
|
||||||
return "true" if part else "false"
|
return "true" if part else "false"
|
||||||
elif isinstance(part, certs.SSLCert):
|
elif isinstance(part, certs.Cert):
|
||||||
return part.to_pem().decode("ascii")
|
return part.to_pem().decode("ascii")
|
||||||
current = part
|
current = part
|
||||||
return str(current or "")
|
return str(current or "")
|
||||||
|
@ -112,7 +112,7 @@ def dummy_cert(privkey, cacert, commonname, sans):
|
|||||||
[OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)])
|
[OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)])
|
||||||
cert.set_pubkey(cacert.get_pubkey())
|
cert.set_pubkey(cacert.get_pubkey())
|
||||||
cert.sign(privkey, "sha256")
|
cert.sign(privkey, "sha256")
|
||||||
return SSLCert(cert)
|
return Cert(cert)
|
||||||
|
|
||||||
|
|
||||||
class CertStoreEntry:
|
class CertStoreEntry:
|
||||||
@ -249,7 +249,7 @@ class CertStore:
|
|||||||
def add_cert_file(self, spec: str, path: str) -> None:
|
def add_cert_file(self, spec: str, path: str) -> None:
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
raw = f.read()
|
raw = f.read()
|
||||||
cert = SSLCert(
|
cert = Cert(
|
||||||
OpenSSL.crypto.load_certificate(
|
OpenSSL.crypto.load_certificate(
|
||||||
OpenSSL.crypto.FILETYPE_PEM,
|
OpenSSL.crypto.FILETYPE_PEM,
|
||||||
raw))
|
raw))
|
||||||
@ -345,7 +345,7 @@ class _GeneralNames(univ.SequenceOf):
|
|||||||
constraint.ValueSizeConstraint(1, 1024)
|
constraint.ValueSizeConstraint(1, 1024)
|
||||||
|
|
||||||
|
|
||||||
class SSLCert(serializable.Serializable):
|
class Cert(serializable.Serializable):
|
||||||
|
|
||||||
def __init__(self, cert):
|
def __init__(self, cert):
|
||||||
"""
|
"""
|
||||||
|
@ -87,8 +87,8 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
|
|||||||
id=str,
|
id=str,
|
||||||
address=tuple,
|
address=tuple,
|
||||||
tls_established=bool,
|
tls_established=bool,
|
||||||
clientcert=certs.SSLCert,
|
clientcert=certs.Cert,
|
||||||
mitmcert=certs.SSLCert,
|
mitmcert=certs.Cert,
|
||||||
timestamp_start=float,
|
timestamp_start=float,
|
||||||
timestamp_tls_setup=float,
|
timestamp_tls_setup=float,
|
||||||
timestamp_end=float,
|
timestamp_end=float,
|
||||||
@ -215,7 +215,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
|||||||
ip_address=tuple,
|
ip_address=tuple,
|
||||||
source_address=tuple,
|
source_address=tuple,
|
||||||
tls_established=bool,
|
tls_established=bool,
|
||||||
cert=certs.SSLCert,
|
cert=certs.Cert,
|
||||||
sni=str,
|
sni=str,
|
||||||
alpn_proto_negotiated=bytes,
|
alpn_proto_negotiated=bytes,
|
||||||
tls_version=str,
|
tls_version=str,
|
||||||
|
@ -161,8 +161,8 @@ def convert_5_6(data):
|
|||||||
data["server_conn"]["tls_established"] = data["server_conn"].pop("ssl_established")
|
data["server_conn"]["tls_established"] = data["server_conn"].pop("ssl_established")
|
||||||
data["server_conn"]["timestamp_tls_setup"] = data["server_conn"].pop("timestamp_ssl_setup")
|
data["server_conn"]["timestamp_tls_setup"] = data["server_conn"].pop("timestamp_ssl_setup")
|
||||||
if data["server_conn"]["via"]:
|
if data["server_conn"]["via"]:
|
||||||
data["server_conn"]["via"]["tls_established"] = data["server_conn"]["via"].pop("ssl_established", None)
|
data["server_conn"]["via"]["tls_established"] = data["server_conn"]["via"].pop("ssl_established")
|
||||||
data["server_conn"]["via"]["timestamp_tls_setup"] = data["server_conn"]["via"].pop("timestamp_ssl_setup", None)
|
data["server_conn"]["via"]["timestamp_tls_setup"] = data["server_conn"]["via"].pop("timestamp_ssl_setup")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
@ -400,11 +400,11 @@ class TCPClient(_Connection):
|
|||||||
else:
|
else:
|
||||||
raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
|
raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
|
||||||
|
|
||||||
self.cert = certs.SSLCert(self.connection.get_peer_certificate())
|
self.cert = certs.Cert(self.connection.get_peer_certificate())
|
||||||
|
|
||||||
# Keep all server certificates in a list
|
# Keep all server certificates in a list
|
||||||
for i in self.connection.get_peer_cert_chain():
|
for i in self.connection.get_peer_cert_chain():
|
||||||
self.server_certs.append(certs.SSLCert(i))
|
self.server_certs.append(certs.Cert(i))
|
||||||
|
|
||||||
self.tls_established = True
|
self.tls_established = True
|
||||||
self.rfile.set_descriptor(self.connection)
|
self.rfile.set_descriptor(self.connection)
|
||||||
@ -510,7 +510,7 @@ class BaseHandler(_Connection):
|
|||||||
self.tls_established = True
|
self.tls_established = True
|
||||||
cert = self.connection.get_peer_certificate()
|
cert = self.connection.get_peer_certificate()
|
||||||
if cert:
|
if cert:
|
||||||
self.clientcert = certs.SSLCert(cert)
|
self.clientcert = certs.Cert(cert)
|
||||||
self.rfile.set_descriptor(self.connection)
|
self.rfile.set_descriptor(self.connection)
|
||||||
self.wfile.set_descriptor(self.connection)
|
self.wfile.set_descriptor(self.connection)
|
||||||
|
|
||||||
|
@ -213,7 +213,7 @@ def create_client_context(
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
if is_cert_verified and depth == 0:
|
if is_cert_verified and depth == 0:
|
||||||
# Verify hostname of leaf certificate.
|
# Verify hostname of leaf certificate.
|
||||||
cert = certs.SSLCert(x509)
|
cert = certs.Cert(x509)
|
||||||
try:
|
try:
|
||||||
crt = dict(
|
crt = dict(
|
||||||
subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in cert.altnames]
|
subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in cert.altnames]
|
||||||
@ -270,17 +270,17 @@ def create_client_context(
|
|||||||
|
|
||||||
|
|
||||||
def create_server_context(
|
def create_server_context(
|
||||||
cert: typing.Union[certs.SSLCert, str],
|
cert: typing.Union[certs.Cert, str],
|
||||||
key: SSL.PKey,
|
key: SSL.PKey,
|
||||||
handle_sni: typing.Optional[typing.Callable[[SSL.Connection], None]] = None,
|
handle_sni: typing.Optional[typing.Callable[[SSL.Connection], None]] = None,
|
||||||
request_client_cert: bool = False,
|
request_client_cert: bool = False,
|
||||||
chain_file=None,
|
chain_file=None,
|
||||||
dhparams=None,
|
dhparams=None,
|
||||||
extra_chain_certs: typing.Iterable[certs.SSLCert] = None,
|
extra_chain_certs: typing.Iterable[certs.Cert] = None,
|
||||||
**sslctx_kwargs
|
**sslctx_kwargs
|
||||||
) -> SSL.Context:
|
) -> SSL.Context:
|
||||||
"""
|
"""
|
||||||
cert: A certs.SSLCert object or the path to a certificate
|
cert: A certs.Cert object or the path to a certificate
|
||||||
chain file.
|
chain file.
|
||||||
|
|
||||||
handle_sni: SNI handler, should take a connection object. Server
|
handle_sni: SNI handler, should take a connection object. Server
|
||||||
@ -321,7 +321,7 @@ def create_server_context(
|
|||||||
)
|
)
|
||||||
|
|
||||||
context.use_privatekey(key)
|
context.use_privatekey(key)
|
||||||
if isinstance(cert, certs.SSLCert):
|
if isinstance(cert, certs.Cert):
|
||||||
context.use_certificate(cert.x509)
|
context.use_certificate(cert.x509)
|
||||||
else:
|
else:
|
||||||
context.use_certificate_chain_file(cert)
|
context.use_certificate_chain_file(cert)
|
||||||
|
@ -79,7 +79,7 @@ class SSLInfo:
|
|||||||
}
|
}
|
||||||
t = types.get(pk.type(), "Uknown")
|
t = types.get(pk.type(), "Uknown")
|
||||||
parts.append("\tPubkey: %s bit %s" % (pk.bits(), t))
|
parts.append("\tPubkey: %s bit %s" % (pk.bits(), t))
|
||||||
s = certs.SSLCert(i)
|
s = certs.Cert(i)
|
||||||
if s.altnames:
|
if s.altnames:
|
||||||
parts.append("\tSANs: %s" % " ".join(strutils.always_str(n, "utf8") for n in s.altnames))
|
parts.append("\tSANs: %s" % " ".join(strutils.always_str(n, "utf8") for n in s.altnames))
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
@ -55,7 +55,7 @@ def test_extract():
|
|||||||
|
|
||||||
with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
||||||
d = f.read()
|
d = f.read()
|
||||||
c1 = certs.SSLCert.from_pem(d)
|
c1 = certs.Cert.from_pem(d)
|
||||||
tf.server_conn.cert = c1
|
tf.server_conn.cert = c1
|
||||||
assert "CERTIFICATE" in cut.extract("server_conn.cert", tf)
|
assert "CERTIFICATE" in cut.extract("server_conn.cert", tf)
|
||||||
|
|
||||||
|
@ -143,9 +143,9 @@ class TcpMixin:
|
|||||||
|
|
||||||
# Test that we get the original SSL cert
|
# Test that we get the original SSL cert
|
||||||
if self.ssl:
|
if self.ssl:
|
||||||
i_cert = certs.SSLCert(i.sslinfo.certchain[0])
|
i_cert = certs.Cert(i.sslinfo.certchain[0])
|
||||||
i2_cert = certs.SSLCert(i2.sslinfo.certchain[0])
|
i2_cert = certs.Cert(i2.sslinfo.certchain[0])
|
||||||
n_cert = certs.SSLCert(n.sslinfo.certchain[0])
|
n_cert = certs.Cert(n.sslinfo.certchain[0])
|
||||||
|
|
||||||
assert i_cert == i2_cert
|
assert i_cert == i2_cert
|
||||||
assert i_cert != n_cert
|
assert i_cert != n_cert
|
||||||
@ -188,9 +188,9 @@ class TcpMixin:
|
|||||||
|
|
||||||
# Test that we get the original SSL cert
|
# Test that we get the original SSL cert
|
||||||
if self.ssl:
|
if self.ssl:
|
||||||
i_cert = certs.SSLCert(i.sslinfo.certchain[0])
|
i_cert = certs.Cert(i.sslinfo.certchain[0])
|
||||||
i2_cert = certs.SSLCert(i2.sslinfo.certchain[0])
|
i2_cert = certs.Cert(i2.sslinfo.certchain[0])
|
||||||
n_cert = certs.SSLCert(n.sslinfo.certchain[0])
|
n_cert = certs.Cert(n.sslinfo.certchain[0])
|
||||||
|
|
||||||
assert i_cert == i2_cert
|
assert i_cert == i2_cert
|
||||||
assert i_cert != n_cert
|
assert i_cert != n_cert
|
||||||
@ -1149,7 +1149,7 @@ class AddUpstreamCertsToClientChainMixin:
|
|||||||
def test_add_upstream_certs_to_client_chain(self):
|
def test_add_upstream_certs_to_client_chain(self):
|
||||||
with open(self.servercert, "rb") as f:
|
with open(self.servercert, "rb") as f:
|
||||||
d = f.read()
|
d = f.read()
|
||||||
upstreamCert = certs.SSLCert.from_pem(d)
|
upstreamCert = certs.Cert.from_pem(d)
|
||||||
p = self.pathoc()
|
p = self.pathoc()
|
||||||
with p.connect():
|
with p.connect():
|
||||||
upstream_cert_found_in_client_chain = False
|
upstream_cert_found_in_client_chain = False
|
||||||
|
@ -136,18 +136,18 @@ class TestDummyCert:
|
|||||||
assert r.altnames == []
|
assert r.altnames == []
|
||||||
|
|
||||||
|
|
||||||
class TestSSLCert:
|
class TestCert:
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
||||||
d = f.read()
|
d = f.read()
|
||||||
c1 = certs.SSLCert.from_pem(d)
|
c1 = certs.Cert.from_pem(d)
|
||||||
assert c1.cn == b"google.com"
|
assert c1.cn == b"google.com"
|
||||||
assert len(c1.altnames) == 436
|
assert len(c1.altnames) == 436
|
||||||
|
|
||||||
with open(tutils.test_data.path("mitmproxy/net/data/text_cert_2"), "rb") as f:
|
with open(tutils.test_data.path("mitmproxy/net/data/text_cert_2"), "rb") as f:
|
||||||
d = f.read()
|
d = f.read()
|
||||||
c2 = certs.SSLCert.from_pem(d)
|
c2 = certs.Cert.from_pem(d)
|
||||||
assert c2.cn == b"www.inode.co.nz"
|
assert c2.cn == b"www.inode.co.nz"
|
||||||
assert len(c2.altnames) == 2
|
assert len(c2.altnames) == 2
|
||||||
assert c2.digest("sha1")
|
assert c2.digest("sha1")
|
||||||
@ -165,20 +165,20 @@ class TestSSLCert:
|
|||||||
def test_err_broken_sans(self):
|
def test_err_broken_sans(self):
|
||||||
with open(tutils.test_data.path("mitmproxy/net/data/text_cert_weird1"), "rb") as f:
|
with open(tutils.test_data.path("mitmproxy/net/data/text_cert_weird1"), "rb") as f:
|
||||||
d = f.read()
|
d = f.read()
|
||||||
c = certs.SSLCert.from_pem(d)
|
c = certs.Cert.from_pem(d)
|
||||||
# This breaks unless we ignore a decoding error.
|
# This breaks unless we ignore a decoding error.
|
||||||
assert c.altnames is not None
|
assert c.altnames is not None
|
||||||
|
|
||||||
def test_der(self):
|
def test_der(self):
|
||||||
with open(tutils.test_data.path("mitmproxy/net/data/dercert"), "rb") as f:
|
with open(tutils.test_data.path("mitmproxy/net/data/dercert"), "rb") as f:
|
||||||
d = f.read()
|
d = f.read()
|
||||||
s = certs.SSLCert.from_der(d)
|
s = certs.Cert.from_der(d)
|
||||||
assert s.cn
|
assert s.cn
|
||||||
|
|
||||||
def test_state(self):
|
def test_state(self):
|
||||||
with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
||||||
d = f.read()
|
d = f.read()
|
||||||
c = certs.SSLCert.from_pem(d)
|
c = certs.Cert.from_pem(d)
|
||||||
|
|
||||||
c.get_state()
|
c.get_state()
|
||||||
c2 = c.copy()
|
c2 = c.copy()
|
||||||
@ -188,6 +188,6 @@ class TestSSLCert:
|
|||||||
assert c == c2
|
assert c == c2
|
||||||
assert c is not c2
|
assert c is not c2
|
||||||
|
|
||||||
x = certs.SSLCert('')
|
x = certs.Cert('')
|
||||||
x.set_state(a)
|
x.set_state(a)
|
||||||
assert x == c
|
assert x == c
|
||||||
|
Loading…
Reference in New Issue
Block a user