Merge pull request #2606 from mhils/issue-2563

Fix #2563
This commit is contained in:
Maximilian Hils 2017-10-25 10:20:09 +02:00 committed by GitHub
commit fdd6bd8277
4 changed files with 28 additions and 62 deletions

View File

@ -4,6 +4,7 @@ import time
import datetime
import ipaddress
import sys
import typing
from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode
@ -114,46 +115,6 @@ def dummy_cert(privkey, cacert, commonname, sans):
return SSLCert(cert)
# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict.
#
# class _Node(UserDict.UserDict):
# def __init__(self):
# UserDict.UserDict.__init__(self)
# self.value = None
#
#
# class DNTree:
# """
# Domain store that knows about wildcards. DNS wildcards are very
# restricted - the only valid variety is an asterisk on the left-most
# domain component, i.e.:
#
# *.foo.com
# """
# def __init__(self):
# self.d = _Node()
#
# def add(self, dn, cert):
# parts = dn.split(".")
# parts.reverse()
# current = self.d
# for i in parts:
# current = current.setdefault(i, _Node())
# current.value = cert
#
# def get(self, dn):
# parts = dn.split(".")
# current = self.d
# for i in reversed(parts):
# if i in current:
# current = current[i]
# elif "*" in current:
# return current["*"].value
# else:
# return None
# return current.value
class CertStoreEntry:
def __init__(self, cert, privatekey, chain_file):
@ -162,6 +123,11 @@ class CertStoreEntry:
self.chain_file = chain_file
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans)
TCertId = typing.Union[TCustomCertId, TGeneratedCertId]
class CertStore:
"""
@ -179,7 +145,7 @@ class CertStore:
self.default_ca = default_ca
self.default_chain_file = default_chain_file
self.dhparams = dhparams
self.certs = dict()
self.certs = {} # type: typing.Dict[TCertId, CertStoreEntry]
self.expire_queue = []
def expire(self, entry):
@ -280,7 +246,7 @@ class CertStore:
return key, ca
def add_cert_file(self, spec, path):
def add_cert_file(self, spec: str, path: str) -> None:
with open(path, "rb") as f:
raw = f.read()
cert = SSLCert(
@ -295,10 +261,10 @@ class CertStore:
privatekey = self.default_privatekey
self.add_cert(
CertStoreEntry(cert, privatekey, path),
spec
spec.encode("idna")
)
def add_cert(self, entry, *names):
def add_cert(self, entry: CertStoreEntry, *names: bytes):
"""
Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument.
@ -311,21 +277,18 @@ class CertStore:
self.certs[i] = entry
@staticmethod
def asterisk_forms(dn):
if dn is None:
return []
def asterisk_forms(dn: bytes) -> typing.List[bytes]:
"""
Return all asterisk forms for a domain. For example, for www.example.com this will return
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
"""
parts = dn.split(b".")
parts.reverse()
curr_dn = b""
dn_forms = [b"*"]
for part in parts[:-1]:
curr_dn = b"." + part + curr_dn # .example.com
dn_forms.append(b"*" + curr_dn) # *.example.com
if parts[-1] != b"*":
dn_forms.append(parts[-1] + curr_dn)
return dn_forms
ret = [dn]
for i in range(1, len(parts)):
ret.append(b"*." + b".".join(parts[i:]))
return ret
def get_cert(self, commonname, sans):
def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes]):
"""
Returns an (cert, privkey, cert_chain) tuple.
@ -335,9 +298,12 @@ class CertStore:
sans: A list of Subject Alternate Names.
"""
potential_keys = self.asterisk_forms(commonname)
potential_keys = [] # type: typing.List[TCertId]
if commonname:
potential_keys.extend(self.asterisk_forms(commonname))
for s in sans:
potential_keys.extend(self.asterisk_forms(s))
potential_keys.append(b"*")
potential_keys.append((commonname, tuple(sans)))
name = next(

View File

@ -479,7 +479,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest):
ssl = True
ssloptions = pathod.SSLOptions(
certs=[
(b"*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
("*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
]
)
@ -1142,7 +1142,7 @@ class AddUpstreamCertsToClientChainMixin:
ssloptions = pathod.SSLOptions(
cn=b"example.mitmproxy.org",
certs=[
(b"example.mitmproxy.org", servercert)
("example.mitmproxy.org", servercert)
]
)

View File

@ -102,7 +102,7 @@ class TestCertStore:
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dcp = tmpdir.join("dc")
dcp.write(dc[0].to_pem())
ca1.add_cert_file(b"foo.com", str(dcp))
ca1.add_cert_file("foo.com", str(dcp))
ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial

View File

@ -57,7 +57,7 @@ class TestNotAfterConnect(tservers.DaemonTests):
class TestCustomCert(tservers.DaemonTests):
ssl = True
ssloptions = dict(
certs=[(b"*", tutils.test_data.path("pathod/data/testkey.pem"))],
certs=[("*", tutils.test_data.path("pathod/data/testkey.pem"))],
)
def test_connect(self):