This commit is contained in:
Maximilian Hils 2017-10-24 21:12:39 +02:00
parent 1f3fec2a3e
commit 4a6d838ecc
4 changed files with 28 additions and 22 deletions

View File

@ -4,6 +4,7 @@ import time
import datetime
import ipaddress
import sys
import typing
from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode
@ -122,6 +123,11 @@ class CertStoreEntry:
self.chain_file = chain_file
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans)
TCertId = typing.Union[TCustomCertId, TGeneratedCertId]
class CertStore:
"""
@ -139,7 +145,7 @@ class CertStore:
self.default_ca = default_ca
self.default_chain_file = default_chain_file
self.dhparams = dhparams
self.certs = dict()
self.certs = {} # type: typing.Dict[TCertId, CertStoreEntry]
self.expire_queue = []
def expire(self, entry):
@ -240,7 +246,7 @@ class CertStore:
return key, ca
def add_cert_file(self, spec, path):
def add_cert_file(self, spec: str, path: str) -> None:
with open(path, "rb") as f:
raw = f.read()
cert = SSLCert(
@ -255,10 +261,10 @@ class CertStore:
privatekey = self.default_privatekey
self.add_cert(
CertStoreEntry(cert, privatekey, path),
spec
spec.encode("idna")
)
def add_cert(self, entry, *names):
def add_cert(self, entry: CertStoreEntry, *names: bytes):
"""
Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument.
@ -271,21 +277,18 @@ class CertStore:
self.certs[i] = entry
@staticmethod
def asterisk_forms(dn):
if dn is None:
return []
def asterisk_forms(dn: bytes) -> typing.List[bytes]:
"""
Return all asterisk forms for a domain. For example, for www.example.com this will return
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
"""
parts = dn.split(b".")
parts.reverse()
curr_dn = b""
dn_forms = [b"*"]
for part in parts[:-1]:
curr_dn = b"." + part + curr_dn # .example.com
dn_forms.append(b"*" + curr_dn) # *.example.com
if parts[-1] != b"*":
dn_forms.append(parts[-1] + curr_dn)
return dn_forms
ret = [dn]
for i in range(1, len(parts)):
ret.append(b"*." + b".".join(parts[i:]))
return ret
def get_cert(self, commonname, sans):
def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes]):
"""
Returns an (cert, privkey, cert_chain) tuple.
@ -295,9 +298,12 @@ class CertStore:
sans: A list of Subject Alternate Names.
"""
potential_keys = self.asterisk_forms(commonname)
potential_keys = [] # type: typing.List[TCertId]
if commonname:
potential_keys.extend(self.asterisk_forms(commonname))
for s in sans:
potential_keys.extend(self.asterisk_forms(s))
potential_keys.append(b"*")
potential_keys.append((commonname, tuple(sans)))
name = next(

View File

@ -479,7 +479,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest):
ssl = True
ssloptions = pathod.SSLOptions(
certs=[
(b"*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
("*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
]
)
@ -1142,7 +1142,7 @@ class AddUpstreamCertsToClientChainMixin:
ssloptions = pathod.SSLOptions(
cn=b"example.mitmproxy.org",
certs=[
(b"example.mitmproxy.org", servercert)
("example.mitmproxy.org", servercert)
]
)

View File

@ -102,7 +102,7 @@ class TestCertStore:
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dcp = tmpdir.join("dc")
dcp.write(dc[0].to_pem())
ca1.add_cert_file(b"foo.com", str(dcp))
ca1.add_cert_file("foo.com", str(dcp))
ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial

View File

@ -57,7 +57,7 @@ class TestNotAfterConnect(tservers.DaemonTests):
class TestCustomCert(tservers.DaemonTests):
ssl = True
ssloptions = dict(
certs=[(b"*", tutils.test_data.path("pathod/data/testkey.pem"))],
certs=[("*", tutils.test_data.path("pathod/data/testkey.pem"))],
)
def test_connect(self):