Merge pull request #2606 from mhils/issue-2563

Fix #2563
This commit is contained in:
Maximilian Hils 2017-10-25 10:20:09 +02:00 committed by GitHub
commit fdd6bd8277
4 changed files with 28 additions and 62 deletions

View File

@ -4,6 +4,7 @@ import time
import datetime import datetime
import ipaddress import ipaddress
import sys import sys
import typing
from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode from pyasn1.codec.der.decoder import decode
@ -114,46 +115,6 @@ def dummy_cert(privkey, cacert, commonname, sans):
return SSLCert(cert) return SSLCert(cert)
# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict.
#
# class _Node(UserDict.UserDict):
# def __init__(self):
# UserDict.UserDict.__init__(self)
# self.value = None
#
#
# class DNTree:
# """
# Domain store that knows about wildcards. DNS wildcards are very
# restricted - the only valid variety is an asterisk on the left-most
# domain component, i.e.:
#
# *.foo.com
# """
# def __init__(self):
# self.d = _Node()
#
# def add(self, dn, cert):
# parts = dn.split(".")
# parts.reverse()
# current = self.d
# for i in parts:
# current = current.setdefault(i, _Node())
# current.value = cert
#
# def get(self, dn):
# parts = dn.split(".")
# current = self.d
# for i in reversed(parts):
# if i in current:
# current = current[i]
# elif "*" in current:
# return current["*"].value
# else:
# return None
# return current.value
class CertStoreEntry: class CertStoreEntry:
def __init__(self, cert, privatekey, chain_file): def __init__(self, cert, privatekey, chain_file):
@ -162,6 +123,11 @@ class CertStoreEntry:
self.chain_file = chain_file self.chain_file = chain_file
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans)
TCertId = typing.Union[TCustomCertId, TGeneratedCertId]
class CertStore: class CertStore:
""" """
@ -179,7 +145,7 @@ class CertStore:
self.default_ca = default_ca self.default_ca = default_ca
self.default_chain_file = default_chain_file self.default_chain_file = default_chain_file
self.dhparams = dhparams self.dhparams = dhparams
self.certs = dict() self.certs = {} # type: typing.Dict[TCertId, CertStoreEntry]
self.expire_queue = [] self.expire_queue = []
def expire(self, entry): def expire(self, entry):
@ -280,7 +246,7 @@ class CertStore:
return key, ca return key, ca
def add_cert_file(self, spec, path): def add_cert_file(self, spec: str, path: str) -> None:
with open(path, "rb") as f: with open(path, "rb") as f:
raw = f.read() raw = f.read()
cert = SSLCert( cert = SSLCert(
@ -295,10 +261,10 @@ class CertStore:
privatekey = self.default_privatekey privatekey = self.default_privatekey
self.add_cert( self.add_cert(
CertStoreEntry(cert, privatekey, path), CertStoreEntry(cert, privatekey, path),
spec spec.encode("idna")
) )
def add_cert(self, entry, *names): def add_cert(self, entry: CertStoreEntry, *names: bytes):
""" """
Adds a cert to the certstore. We register the CN in the cert plus Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument. any SANs, and also the list of names provided as an argument.
@ -311,21 +277,18 @@ class CertStore:
self.certs[i] = entry self.certs[i] = entry
@staticmethod @staticmethod
def asterisk_forms(dn): def asterisk_forms(dn: bytes) -> typing.List[bytes]:
if dn is None: """
return [] Return all asterisk forms for a domain. For example, for www.example.com this will return
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
"""
parts = dn.split(b".") parts = dn.split(b".")
parts.reverse() ret = [dn]
curr_dn = b"" for i in range(1, len(parts)):
dn_forms = [b"*"] ret.append(b"*." + b".".join(parts[i:]))
for part in parts[:-1]: return ret
curr_dn = b"." + part + curr_dn # .example.com
dn_forms.append(b"*" + curr_dn) # *.example.com
if parts[-1] != b"*":
dn_forms.append(parts[-1] + curr_dn)
return dn_forms
def get_cert(self, commonname, sans): def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes]):
""" """
Returns an (cert, privkey, cert_chain) tuple. Returns an (cert, privkey, cert_chain) tuple.
@ -335,9 +298,12 @@ class CertStore:
sans: A list of Subject Alternate Names. sans: A list of Subject Alternate Names.
""" """
potential_keys = self.asterisk_forms(commonname) potential_keys = [] # type: typing.List[TCertId]
if commonname:
potential_keys.extend(self.asterisk_forms(commonname))
for s in sans: for s in sans:
potential_keys.extend(self.asterisk_forms(s)) potential_keys.extend(self.asterisk_forms(s))
potential_keys.append(b"*")
potential_keys.append((commonname, tuple(sans))) potential_keys.append((commonname, tuple(sans)))
name = next( name = next(

View File

@ -479,7 +479,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest):
ssl = True ssl = True
ssloptions = pathod.SSLOptions( ssloptions = pathod.SSLOptions(
certs=[ certs=[
(b"*", tutils.test_data.path("mitmproxy/data/no_common_name.pem")) ("*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
] ]
) )
@ -1142,7 +1142,7 @@ class AddUpstreamCertsToClientChainMixin:
ssloptions = pathod.SSLOptions( ssloptions = pathod.SSLOptions(
cn=b"example.mitmproxy.org", cn=b"example.mitmproxy.org",
certs=[ certs=[
(b"example.mitmproxy.org", servercert) ("example.mitmproxy.org", servercert)
] ]
) )

View File

@ -102,7 +102,7 @@ class TestCertStore:
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dcp = tmpdir.join("dc") dcp = tmpdir.join("dc")
dcp.write(dc[0].to_pem()) dcp.write(dc[0].to_pem())
ca1.add_cert_file(b"foo.com", str(dcp)) ca1.add_cert_file("foo.com", str(dcp))
ret = ca1.get_cert(b"foo.com", []) ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial assert ret[0].serial == dc[0].serial

View File

@ -57,7 +57,7 @@ class TestNotAfterConnect(tservers.DaemonTests):
class TestCustomCert(tservers.DaemonTests): class TestCustomCert(tservers.DaemonTests):
ssl = True ssl = True
ssloptions = dict( ssloptions = dict(
certs=[(b"*", tutils.test_data.path("pathod/data/testkey.pem"))], certs=[("*", tutils.test_data.path("pathod/data/testkey.pem"))],
) )
def test_connect(self): def test_connect(self):