Much more sophisticated certificate store

- Handle wildcard lookup
- Handle lookup of SANs
- Provide hooks for registering override certs and keys for specific
domains (including wildcard specifications)
This commit is contained in:
Aldo Cortesi 2014-03-05 13:19:16 +13:00
parent 7c82418e0b
commit 0c3bc1cff2
2 changed files with 140 additions and 15 deletions

View File

@ -4,6 +4,7 @@ from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error from pyasn1.error import PyAsn1Error
import OpenSSL import OpenSSL
import tcp import tcp
import UserDict
DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720 DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720
# Generated with "openssl dhparam". It's too slow to generate this on startup. # Generated with "openssl dhparam". It's too slow to generate this on startup.
@ -42,11 +43,11 @@ def create_ca(o, cn, exp):
return key, cert return key, cert
def dummy_cert(pkey, cacert, commonname, sans): def dummy_cert(privkey, cacert, commonname, sans):
""" """
Generates a dummy certificate. Generates a dummy certificate.
pkey: CA private key privkey: CA private key
cacert: CA certificate cacert: CA certificate
commonname: Common name for the generated certificate. commonname: Common name for the generated certificate.
sans: A list of Subject Alternate Names. sans: A list of Subject Alternate Names.
@ -68,17 +69,55 @@ def dummy_cert(pkey, cacert, commonname, sans):
cert.set_version(2) cert.set_version(2)
cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)])
cert.set_pubkey(cacert.get_pubkey()) cert.set_pubkey(cacert.get_pubkey())
cert.sign(pkey, "sha1") cert.sign(privkey, "sha1")
return SSLCert(cert) return SSLCert(cert)
class _Node(UserDict.UserDict):
def __init__(self):
UserDict.UserDict.__init__(self)
self.value = None
class DNTree:
"""
Domain store that knows about wildcards. DNS wildcards are very
restricted - the only valid variety is an asterisk on the left-most
domain component, i.e.:
*.foo.com
"""
def __init__(self):
self.d = _Node()
def add(self, dn, cert):
parts = dn.split(".")
parts.reverse()
current = self.d
for i in parts:
current = current.setdefault(i, _Node())
current.value = cert
def get(self, dn):
parts = dn.split(".")
current = self.d
for i in reversed(parts):
if i in current:
current = current[i]
elif "*" in current:
return current["*"].value
else:
return None
return current.value
class CertStore: class CertStore:
""" """
Implements an in-memory certificate store. Implements an in-memory certificate store.
""" """
def __init__(self, pkey, cert): def __init__(self, privkey, cacert):
self.pkey, self.cert = pkey, cert self.privkey, self.cacert = privkey, cacert
self.certs = {} self.certs = DNTree()
@classmethod @classmethod
def from_store(klass, path, basename): def from_store(klass, path, basename):
@ -130,9 +169,29 @@ class CertStore:
f.close() f.close()
return key, ca return key, ca
def add_cert_file(self, commonname, path):
raw = file(path, "rb").read()
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
try:
privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
except Exception:
privkey = None
self.add_cert(SSLCert(cert), privkey, commonname)
def add_cert(self, cert, privkey, *names):
"""
Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument.
"""
self.certs.add(cert.cn, (cert, privkey))
for i in cert.altnames:
self.certs.add(i, (cert, privkey))
for i in names:
self.certs.add(i, (cert, privkey))
def get_cert(self, commonname, sans): def get_cert(self, commonname, sans):
""" """
Returns an SSLCert object. Returns an (cert, privkey) tuple.
commonname: Common name for the generated certificate. Must be a commonname: Common name for the generated certificate. Must be a
valid, plain-ASCII, IDNA-encoded domain name. valid, plain-ASCII, IDNA-encoded domain name.
@ -141,11 +200,12 @@ class CertStore:
Return None if the certificate could not be found or generated. Return None if the certificate could not be found or generated.
""" """
if commonname in self.certs: c = self.certs.get(commonname)
return self.certs[commonname] if not c:
c = dummy_cert(self.pkey, self.cert, commonname, sans) c = dummy_cert(self.privkey, self.cacert, commonname, sans)
self.certs[commonname] = c self.add_cert(c, None)
return c c = (c, None)
return (c[0], c[1] or self.privkey)
class _GeneralName(univ.Choice): class _GeneralName(univ.Choice):
@ -171,6 +231,9 @@ class SSLCert:
""" """
self.x509 = cert self.x509 = cert
def __eq__(self, other):
return self.digest("sha1") == other.digest("sha1")
@classmethod @classmethod
def from_pem(klass, txt): def from_pem(klass, txt):
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)

View File

@ -1,7 +1,37 @@
import os import os
from netlib import certutils from netlib import certutils
import OpenSSL
import tutils import tutils
class TestDNTree:
def test_simple(self):
d = certutils.DNTree()
d.add("foo.com", "foo")
d.add("bar.com", "bar")
assert d.get("foo.com") == "foo"
assert d.get("bar.com") == "bar"
assert not d.get("oink.com")
assert not d.get("oink")
assert not d.get("")
assert not d.get("oink.oink")
d.add("*.match.org", "match")
assert not d.get("match.org")
assert d.get("foo.match.org") == "match"
assert d.get("foo.foo.match.org") == "match"
def test_wildcard(self):
d = certutils.DNTree()
d.add("foo.com", "foo")
assert not d.get("*.foo.com")
d.add("*.foo.com", "wild")
d = certutils.DNTree()
d.add("*", "foo")
assert d.get("foo.com") == "foo"
assert d.get("*.foo.com") == "foo"
assert d.get("com") == "foo"
class TestCertStore: class TestCertStore:
def test_create_explicit(self): def test_create_explicit(self):
@ -12,7 +42,7 @@ class TestCertStore:
ca2 = certutils.CertStore.from_store(d, "test") ca2 = certutils.CertStore.from_store(d, "test")
assert ca2.get_cert("foo", []) assert ca2.get_cert("foo", [])
assert ca.cert.get_serial_number() == ca2.cert.get_serial_number() assert ca.cacert.get_serial_number() == ca2.cacert.get_serial_number()
def test_create_tmp(self): def test_create_tmp(self):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
@ -21,14 +51,46 @@ class TestCertStore:
assert ca.get_cert("foo.com", []) assert ca.get_cert("foo.com", [])
assert ca.get_cert("*.foo.com", []) assert ca.get_cert("*.foo.com", [])
r = ca.get_cert("*.foo.com", [])
assert r[1] == ca.privkey
def test_add_cert(self):
with tutils.tmpdir() as d:
ca = certutils.CertStore.from_store(d, "test")
def test_sans(self):
with tutils.tmpdir() as d:
ca = certutils.CertStore.from_store(d, "test")
c1 = ca.get_cert("foo.com", ["*.bar.com"])
c2 = ca.get_cert("foo.bar.com", [])
assert c1 == c2
c3 = ca.get_cert("bar.com", [])
assert not c1 == c3
def test_overrides(self):
with tutils.tmpdir() as d:
ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
assert not ca1.cacert.get_serial_number() == ca2.cacert.get_serial_number()
dc = ca2.get_cert("foo.com", [])
dcp = os.path.join(d, "dc")
f = open(dcp, "wb")
f.write(dc[0].to_pem())
f.close()
ca1.add_cert_file("foo.com", dcp)
ret = ca1.get_cert("foo.com", [])
assert ret[0].serial == dc[0].serial
class TestDummyCert: class TestDummyCert:
def test_with_ca(self): def test_with_ca(self):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
ca = certutils.CertStore.from_store(d, "test") ca = certutils.CertStore.from_store(d, "test")
r = certutils.dummy_cert( r = certutils.dummy_cert(
ca.pkey, ca.privkey,
ca.cert, ca.cacert,
"foo.com", "foo.com",
["one.com", "two.com", "*.three.com"] ["one.com", "two.com", "*.three.com"]
) )