This commit is contained in:
Maximilian Hils 2017-10-24 21:12:39 +02:00
parent 1f3fec2a3e
commit 4a6d838ecc
4 changed files with 28 additions and 22 deletions

View File

@ -4,6 +4,7 @@ import time
import datetime import datetime
import ipaddress import ipaddress
import sys import sys
import typing
from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode from pyasn1.codec.der.decoder import decode
@ -122,6 +123,11 @@ class CertStoreEntry:
self.chain_file = chain_file self.chain_file = chain_file
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans)
TCertId = typing.Union[TCustomCertId, TGeneratedCertId]
class CertStore: class CertStore:
""" """
@ -139,7 +145,7 @@ class CertStore:
self.default_ca = default_ca self.default_ca = default_ca
self.default_chain_file = default_chain_file self.default_chain_file = default_chain_file
self.dhparams = dhparams self.dhparams = dhparams
self.certs = dict() self.certs = {} # type: typing.Dict[TCertId, CertStoreEntry]
self.expire_queue = [] self.expire_queue = []
def expire(self, entry): def expire(self, entry):
@ -240,7 +246,7 @@ class CertStore:
return key, ca return key, ca
def add_cert_file(self, spec, path): def add_cert_file(self, spec: str, path: str) -> None:
with open(path, "rb") as f: with open(path, "rb") as f:
raw = f.read() raw = f.read()
cert = SSLCert( cert = SSLCert(
@ -255,10 +261,10 @@ class CertStore:
privatekey = self.default_privatekey privatekey = self.default_privatekey
self.add_cert( self.add_cert(
CertStoreEntry(cert, privatekey, path), CertStoreEntry(cert, privatekey, path),
spec spec.encode("idna")
) )
def add_cert(self, entry, *names): def add_cert(self, entry: CertStoreEntry, *names: bytes):
""" """
Adds a cert to the certstore. We register the CN in the cert plus Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument. any SANs, and also the list of names provided as an argument.
@ -271,21 +277,18 @@ class CertStore:
self.certs[i] = entry self.certs[i] = entry
@staticmethod @staticmethod
def asterisk_forms(dn): def asterisk_forms(dn: bytes) -> typing.List[bytes]:
if dn is None: """
return [] Return all asterisk forms for a domain. For example, for www.example.com this will return
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
"""
parts = dn.split(b".") parts = dn.split(b".")
parts.reverse() ret = [dn]
curr_dn = b"" for i in range(1, len(parts)):
dn_forms = [b"*"] ret.append(b"*." + b".".join(parts[i:]))
for part in parts[:-1]: return ret
curr_dn = b"." + part + curr_dn # .example.com
dn_forms.append(b"*" + curr_dn) # *.example.com
if parts[-1] != b"*":
dn_forms.append(parts[-1] + curr_dn)
return dn_forms
def get_cert(self, commonname, sans): def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes]):
""" """
Returns an (cert, privkey, cert_chain) tuple. Returns an (cert, privkey, cert_chain) tuple.
@ -295,9 +298,12 @@ class CertStore:
sans: A list of Subject Alternate Names. sans: A list of Subject Alternate Names.
""" """
potential_keys = self.asterisk_forms(commonname) potential_keys = [] # type: typing.List[TCertId]
if commonname:
potential_keys.extend(self.asterisk_forms(commonname))
for s in sans: for s in sans:
potential_keys.extend(self.asterisk_forms(s)) potential_keys.extend(self.asterisk_forms(s))
potential_keys.append(b"*")
potential_keys.append((commonname, tuple(sans))) potential_keys.append((commonname, tuple(sans)))
name = next( name = next(

View File

@ -479,7 +479,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest):
ssl = True ssl = True
ssloptions = pathod.SSLOptions( ssloptions = pathod.SSLOptions(
certs=[ certs=[
(b"*", tutils.test_data.path("mitmproxy/data/no_common_name.pem")) ("*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
] ]
) )
@ -1142,7 +1142,7 @@ class AddUpstreamCertsToClientChainMixin:
ssloptions = pathod.SSLOptions( ssloptions = pathod.SSLOptions(
cn=b"example.mitmproxy.org", cn=b"example.mitmproxy.org",
certs=[ certs=[
(b"example.mitmproxy.org", servercert) ("example.mitmproxy.org", servercert)
] ]
) )

View File

@ -102,7 +102,7 @@ class TestCertStore:
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dcp = tmpdir.join("dc") dcp = tmpdir.join("dc")
dcp.write(dc[0].to_pem()) dcp.write(dc[0].to_pem())
ca1.add_cert_file(b"foo.com", str(dcp)) ca1.add_cert_file("foo.com", str(dcp))
ret = ca1.get_cert(b"foo.com", []) ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial assert ret[0].serial == dc[0].serial

View File

@ -57,7 +57,7 @@ class TestNotAfterConnect(tservers.DaemonTests):
class TestCustomCert(tservers.DaemonTests): class TestCustomCert(tservers.DaemonTests):
ssl = True ssl = True
ssloptions = dict( ssloptions = dict(
certs=[(b"*", tutils.test_data.path("pathod/data/testkey.pem"))], certs=[("*", tutils.test_data.path("pathod/data/testkey.pem"))],
) )
def test_connect(self): def test_connect(self):