CertStore: add support for cert chains

This commit is contained in:
Maximilian Hils 2014-10-08 20:46:30 +02:00
parent 274688172d
commit fdb6f5552d
4 changed files with 55 additions and 43 deletions

View File

@ -113,13 +113,21 @@ def dummy_cert(privkey, cacert, commonname, sans):
# return current.value # return current.value
class CertStoreEntry(object):
def __init__(self, cert, pkey=None, chain_file=None):
self.cert = cert
self.pkey = pkey
self.chain_file = chain_file
class CertStore: class CertStore:
""" """
Implements an in-memory certificate store. Implements an in-memory certificate store.
""" """
def __init__(self, privkey, cacert, dhparams=None): def __init__(self, default_pkey, default_ca, default_chain_file, dhparams=None):
self.privkey, self.cacert = privkey, cacert self.default_pkey = default_pkey
self.default_ca = default_ca
self.default_chain_file = default_chain_file
self.dhparams = dhparams self.dhparams = dhparams
self.certs = dict() self.certs = dict()
@ -142,21 +150,21 @@ class CertStore:
return dh return dh
@classmethod @classmethod
def from_store(klass, path, basename): def from_store(cls, path, basename):
p = os.path.join(path, basename + "-ca.pem") ca_path = os.path.join(path, basename + "-ca.pem")
if not os.path.exists(p): if not os.path.exists(ca_path):
key, ca = klass.create_store(path, basename) key, ca = cls.create_store(path, basename)
else: else:
p = os.path.join(path, basename + "-ca.pem") with open(ca_path, "rb") as f:
raw = file(p, "rb").read() raw = f.read()
ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
dhp = os.path.join(path, basename + "-dhparam.pem") dh_path = os.path.join(path, basename + "-dhparam.pem")
dh = klass.load_dhparam(dhp) dh = cls.load_dhparam(dh_path)
return klass(key, ca, dh) return cls(key, ca, ca_path, dh)
@classmethod @classmethod
def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): def create_store(cls, path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
@ -194,25 +202,29 @@ class CertStore:
return key, ca return key, ca
def add_cert_file(self, spec, path): def add_cert_file(self, spec, path):
raw = file(path, "rb").read() with open(path, "rb") as f:
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw) raw = f.read()
cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw))
try: try:
privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) pkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
except Exception: except Exception:
privkey = None pkey = None
self.add_cert(SSLCert(cert), privkey, spec) self.add_cert(
CertStoreEntry(cert, pkey, path),
spec
)
def add_cert(self, cert, privkey, *names): def add_cert(self, entry, *names):
""" """
Adds a cert to the certstore. We register the CN in the cert plus Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument. any SANs, and also the list of names provided as an argument.
""" """
if cert.cn: if entry.cert.cn:
self.certs[cert.cn] = (cert, privkey) self.certs[entry.cert.cn] = entry
for i in cert.altnames: for i in entry.cert.altnames:
self.certs[i] = (cert, privkey) self.certs[i] = entry
for i in names: for i in names:
self.certs[i] = (cert, privkey) self.certs[i] = entry
@staticmethod @staticmethod
def asterisk_forms(dn): def asterisk_forms(dn):
@ -246,17 +258,17 @@ class CertStore:
name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None) name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None)
if name: if name:
c = self.certs[name] entry = self.certs[name]
else: else:
c = dummy_cert(self.privkey, self.cacert, commonname, sans), None entry = CertStoreEntry(cert=dummy_cert(self.default_pkey, self.default_ca, commonname, sans))
self.certs[(commonname, tuple(sans))] = c self.certs[(commonname, tuple(sans))] = entry
return c[0], (c[1] or self.privkey) return entry.cert, (entry.pkey or self.default_pkey), (entry.chain_file or self.default_chain_file)
def gen_pkey(self, cert): def gen_pkey(self, cert):
from . import certffi from . import certffi
certffi.set_flags(self.privkey, 1) certffi.set_flags(self.default_pkey, 1)
return self.privkey return self.default_pkey
class _GeneralName(univ.Choice): class _GeneralName(univ.Choice):

View File

@ -345,7 +345,7 @@ class BaseHandler(_Connection):
def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None, def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None,
handle_sni=None, request_client_cert=None, cipher_list=None, handle_sni=None, request_client_cert=None, cipher_list=None,
dhparams=None, ca_file=None): dhparams=None, chain_file=None):
""" """
cert: A certutils.SSLCert object. cert: A certutils.SSLCert object.
@ -377,8 +377,8 @@ class BaseHandler(_Connection):
ctx = SSL.Context(method) ctx = SSL.Context(method)
if not options is None: if not options is None:
ctx.set_options(options) ctx.set_options(options)
if ca_file: if chain_file:
ctx.load_verify_locations(ca_file) ctx.load_verify_locations(chain_file)
if cipher_list: if cipher_list:
try: try:
ctx.set_cipher_list(cipher_list) ctx.set_cipher_list(cipher_list)

View File

@ -42,7 +42,7 @@ class TestCertStore:
ca2 = certutils.CertStore.from_store(d, "test") ca2 = certutils.CertStore.from_store(d, "test")
assert ca2.get_cert("foo", []) assert ca2.get_cert("foo", [])
assert ca.cacert.get_serial_number() == ca2.cacert.get_serial_number() assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
def test_create_tmp(self): def test_create_tmp(self):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
@ -52,7 +52,7 @@ class TestCertStore:
assert ca.get_cert("*.foo.com", []) assert ca.get_cert("*.foo.com", [])
r = ca.get_cert("*.foo.com", []) r = ca.get_cert("*.foo.com", [])
assert r[1] == ca.privkey assert r[1] == ca.default_pkey
def test_add_cert(self): def test_add_cert(self):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
@ -71,14 +71,14 @@ class TestCertStore:
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
ca = certutils.CertStore.from_store(d, "test") ca = certutils.CertStore.from_store(d, "test")
_ = ca.get_cert("foo.com", ["*.bar.com"]) _ = ca.get_cert("foo.com", ["*.bar.com"])
cert, key = ca.get_cert("foo.bar.com", ["*.baz.com"]) cert, key, chain_file = ca.get_cert("foo.bar.com", ["*.baz.com"])
assert "*.baz.com" in cert.altnames assert "*.baz.com" in cert.altnames
def test_overrides(self): def test_overrides(self):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test") ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test") ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
assert not ca1.cacert.get_serial_number() == ca2.cacert.get_serial_number() assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
dc = ca2.get_cert("foo.com", []) dc = ca2.get_cert("foo.com", [])
dcp = os.path.join(d, "dc") dcp = os.path.join(d, "dc")
@ -98,7 +98,7 @@ class TestCertStore:
cert = ca1.get_cert("foo.com", []) cert = ca1.get_cert("foo.com", [])
assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1 assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1
finally: finally:
certffi.set_flags(ca2.privkey, 0) certffi.set_flags(ca2.default_pkey, 0)
class TestDummyCert: class TestDummyCert:
@ -106,8 +106,8 @@ class TestDummyCert:
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
ca = certutils.CertStore.from_store(d, "test") ca = certutils.CertStore.from_store(d, "test")
r = certutils.dummy_cert( r = certutils.dummy_cert(
ca.privkey, ca.default_pkey,
ca.cacert, ca.default_ca,
"foo.com", "foo.com",
["one.com", "two.com", "*.three.com"] ["one.com", "two.com", "*.three.com"]
) )

View File

@ -393,7 +393,7 @@ class TestPrivkeyGen(test.ServerTestBase):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
ca1 = certutils.CertStore.from_store(d, "test2") ca1 = certutils.CertStore.from_store(d, "test2")
ca2 = certutils.CertStore.from_store(d, "test3") ca2 = certutils.CertStore.from_store(d, "test3")
cert, _ = ca1.get_cert("foo.com", []) cert, _, _ = ca1.get_cert("foo.com", [])
key = ca2.gen_pkey(cert) key = ca2.gen_pkey(cert)
self.convert_to_ssl(cert, key) self.convert_to_ssl(cert, key)
@ -409,9 +409,9 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
ca1 = certutils.CertStore.from_store(d, "test2") ca1 = certutils.CertStore.from_store(d, "test2")
ca2 = certutils.CertStore.from_store(d, "test3") ca2 = certutils.CertStore.from_store(d, "test3")
cert, _ = ca1.get_cert("foo.com", []) cert, _, _ = ca1.get_cert("foo.com", [])
certffi.set_flags(ca2.privkey, 0) certffi.set_flags(ca2.default_pkey, 0)
self.convert_to_ssl(cert, ca2.privkey) self.convert_to_ssl(cert, ca2.default_pkey)
def test_privkey(self): def test_privkey(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))