diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index 5a737b618..572a12d03 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -4,6 +4,7 @@ import time import datetime import ipaddress import sys +import typing from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode @@ -122,6 +123,11 @@ class CertStoreEntry: self.chain_file = chain_file +TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs) +TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans) +TCertId = typing.Union[TCustomCertId, TGeneratedCertId] + + class CertStore: """ @@ -139,7 +145,7 @@ class CertStore: self.default_ca = default_ca self.default_chain_file = default_chain_file self.dhparams = dhparams - self.certs = dict() + self.certs = {} # type: typing.Dict[TCertId, CertStoreEntry] self.expire_queue = [] def expire(self, entry): @@ -240,7 +246,7 @@ class CertStore: return key, ca - def add_cert_file(self, spec, path): + def add_cert_file(self, spec: str, path: str) -> None: with open(path, "rb") as f: raw = f.read() cert = SSLCert( @@ -255,10 +261,10 @@ class CertStore: privatekey = self.default_privatekey self.add_cert( CertStoreEntry(cert, privatekey, path), - spec + spec.encode("idna") ) - def add_cert(self, entry, *names): + def add_cert(self, entry: CertStoreEntry, *names: bytes): """ Adds a cert to the certstore. We register the CN in the cert plus any SANs, and also the list of names provided as an argument. @@ -271,21 +277,18 @@ class CertStore: self.certs[i] = entry @staticmethod - def asterisk_forms(dn): - if dn is None: - return [] + def asterisk_forms(dn: bytes) -> typing.List[bytes]: + """ + Return all asterisk forms for a domain. For example, for www.example.com this will return + [b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted. + """ parts = dn.split(b".") - parts.reverse() - curr_dn = b"" - dn_forms = [b"*"] - for part in parts[:-1]: - curr_dn = b"." + part + curr_dn # .example.com - dn_forms.append(b"*" + curr_dn) # *.example.com - if parts[-1] != b"*": - dn_forms.append(parts[-1] + curr_dn) - return dn_forms + ret = [dn] + for i in range(1, len(parts)): + ret.append(b"*." + b".".join(parts[i:])) + return ret - def get_cert(self, commonname, sans): + def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes]): """ Returns an (cert, privkey, cert_chain) tuple. @@ -295,9 +298,12 @@ class CertStore: sans: A list of Subject Alternate Names. """ - potential_keys = self.asterisk_forms(commonname) + potential_keys = [] # type: typing.List[TCertId] + if commonname: + potential_keys.extend(self.asterisk_forms(commonname)) for s in sans: potential_keys.extend(self.asterisk_forms(s)) + potential_keys.append(b"*") potential_keys.append((commonname, tuple(sans))) name = next( diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py index affdf221f..8dce9bcd5 100644 --- a/test/mitmproxy/proxy/test_server.py +++ b/test/mitmproxy/proxy/test_server.py @@ -479,7 +479,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest): ssl = True ssloptions = pathod.SSLOptions( certs=[ - (b"*", tutils.test_data.path("mitmproxy/data/no_common_name.pem")) + ("*", tutils.test_data.path("mitmproxy/data/no_common_name.pem")) ] ) @@ -1142,7 +1142,7 @@ class AddUpstreamCertsToClientChainMixin: ssloptions = pathod.SSLOptions( cn=b"example.mitmproxy.org", certs=[ - (b"example.mitmproxy.org", servercert) + ("example.mitmproxy.org", servercert) ] ) diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index 88c495613..693bebc60 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -102,7 +102,7 @@ class TestCertStore: dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) dcp = tmpdir.join("dc") dcp.write(dc[0].to_pem()) - ca1.add_cert_file(b"foo.com", str(dcp)) + ca1.add_cert_file("foo.com", str(dcp)) ret = ca1.get_cert(b"foo.com", []) assert ret[0].serial == dc[0].serial diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py index 5f191c0db..c00119522 100644 --- a/test/pathod/test_pathod.py +++ b/test/pathod/test_pathod.py @@ -57,7 +57,7 @@ class TestNotAfterConnect(tservers.DaemonTests): class TestCustomCert(tservers.DaemonTests): ssl = True ssloptions = dict( - certs=[(b"*", tutils.test_data.path("pathod/data/testkey.pem"))], + certs=[("*", tutils.test_data.path("pathod/data/testkey.pem"))], ) def test_connect(self):