clean up code

This commit is contained in:
Maximilian Hils 2014-10-09 00:15:39 +02:00
parent fdb6f5552d
commit 9ef84ccc1c
3 changed files with 41 additions and 40 deletions

View File

@ -114,9 +114,9 @@ def dummy_cert(privkey, cacert, commonname, sans):
class CertStoreEntry(object): class CertStoreEntry(object):
def __init__(self, cert, pkey=None, chain_file=None): def __init__(self, cert, privatekey, chain_file):
self.cert = cert self.cert = cert
self.pkey = pkey self.privatekey = privatekey
self.chain_file = chain_file self.chain_file = chain_file
@ -124,15 +124,15 @@ class CertStore:
""" """
Implements an in-memory certificate store. Implements an in-memory certificate store.
""" """
def __init__(self, default_pkey, default_ca, default_chain_file, dhparams=None): def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None):
self.default_pkey = default_pkey self.default_privatekey = default_privatekey
self.default_ca = default_ca self.default_ca = default_ca
self.default_chain_file = default_chain_file self.default_chain_file = default_chain_file
self.dhparams = dhparams self.dhparams = dhparams
self.certs = dict() self.certs = dict()
@classmethod @staticmethod
def load_dhparam(klass, path): def load_dhparam(path):
# netlib<=0.10 doesn't generate a dhparam file. # netlib<=0.10 doesn't generate a dhparam file.
# Create it now if neccessary. # Create it now if neccessary.
@ -163,8 +163,8 @@ class CertStore:
dh = cls.load_dhparam(dh_path) dh = cls.load_dhparam(dh_path)
return cls(key, ca, ca_path, dh) return cls(key, ca, ca_path, dh)
@classmethod @staticmethod
def create_store(cls, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
@ -173,32 +173,28 @@ class CertStore:
key, ca = create_ca(o=o, cn=cn, exp=expiry) key, ca = create_ca(o=o, cn=cn, exp=expiry)
# Dump the CA plus private key # Dump the CA plus private key
f = open(os.path.join(path, basename + "-ca.pem"), "wb") with open(os.path.join(path, basename + "-ca.pem"), "wb") as f:
f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key))
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Dump the certificate in PEM format # Dump the certificate in PEM format
f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb") with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f:
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Create a .cer file with the same contents for Android # Create a .cer file with the same contents for Android
f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb") with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f:
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Dump the certificate in PKCS12 format for Windows devices # Dump the certificate in PKCS12 format for Windows devices
f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb") with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f:
p12 = OpenSSL.crypto.PKCS12() p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca) p12.set_certificate(ca)
p12.set_privatekey(key) p12.set_privatekey(key)
f.write(p12.export()) f.write(p12.export())
f.close()
f = open(os.path.join(path, basename + "-dhparam.pem"), "wb") with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f:
f.write(DEFAULT_DHPARAM) f.write(DEFAULT_DHPARAM)
f.close()
return key, ca return key, ca
def add_cert_file(self, spec, path): def add_cert_file(self, spec, path):
@ -206,11 +202,11 @@ class CertStore:
raw = f.read() raw = f.read()
cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw))
try: try:
pkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) privatekey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
except Exception: except Exception:
pkey = None privatekey = self.default_privatekey
self.add_cert( self.add_cert(
CertStoreEntry(cert, pkey, path), CertStoreEntry(cert, privatekey, path),
spec spec
) )
@ -241,7 +237,7 @@ class CertStore:
def get_cert(self, commonname, sans): def get_cert(self, commonname, sans):
""" """
Returns an (cert, privkey) tuple. Returns an (cert, privkey, cert_chain) tuple.
commonname: Common name for the generated certificate. Must be a commonname: Common name for the generated certificate. Must be a
valid, plain-ASCII, IDNA-encoded domain name. valid, plain-ASCII, IDNA-encoded domain name.
@ -260,15 +256,20 @@ class CertStore:
if name: if name:
entry = self.certs[name] entry = self.certs[name]
else: else:
entry = CertStoreEntry(cert=dummy_cert(self.default_pkey, self.default_ca, commonname, sans)) entry = CertStoreEntry(
cert=dummy_cert(self.default_privatekey, self.default_ca, commonname, sans),
privatekey=self.default_privatekey,
chain_file=self.default_chain_file
)
self.certs[(commonname, tuple(sans))] = entry self.certs[(commonname, tuple(sans))] = entry
return entry.cert, (entry.pkey or self.default_pkey), (entry.chain_file or self.default_chain_file) return entry.cert, entry.privatekey, entry.chain_file
def gen_pkey(self, cert): def gen_pkey(self, cert):
# FIXME: We should do something with cert here?
from . import certffi from . import certffi
certffi.set_flags(self.default_pkey, 1) certffi.set_flags(self.default_privatekey, 1)
return self.default_pkey return self.default_privatekey
class _GeneralName(univ.Choice): class _GeneralName(univ.Choice):

View File

@ -52,7 +52,7 @@ class TestCertStore:
assert ca.get_cert("*.foo.com", []) assert ca.get_cert("*.foo.com", [])
r = ca.get_cert("*.foo.com", []) r = ca.get_cert("*.foo.com", [])
assert r[1] == ca.default_pkey assert r[1] == ca.default_privatekey
def test_add_cert(self): def test_add_cert(self):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
@ -98,7 +98,7 @@ class TestCertStore:
cert = ca1.get_cert("foo.com", []) cert = ca1.get_cert("foo.com", [])
assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1 assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1
finally: finally:
certffi.set_flags(ca2.default_pkey, 0) certffi.set_flags(ca2.default_privatekey, 0)
class TestDummyCert: class TestDummyCert:
@ -106,7 +106,7 @@ class TestDummyCert:
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
ca = certutils.CertStore.from_store(d, "test") ca = certutils.CertStore.from_store(d, "test")
r = certutils.dummy_cert( r = certutils.dummy_cert(
ca.default_pkey, ca.default_privatekey,
ca.default_ca, ca.default_ca,
"foo.com", "foo.com",
["one.com", "two.com", "*.three.com"] ["one.com", "two.com", "*.three.com"]

View File

@ -410,8 +410,8 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase):
ca1 = certutils.CertStore.from_store(d, "test2") ca1 = certutils.CertStore.from_store(d, "test2")
ca2 = certutils.CertStore.from_store(d, "test3") ca2 = certutils.CertStore.from_store(d, "test3")
cert, _, _ = ca1.get_cert("foo.com", []) cert, _, _ = ca1.get_cert("foo.com", [])
certffi.set_flags(ca2.default_pkey, 0) certffi.set_flags(ca2.default_privatekey, 0)
self.convert_to_ssl(cert, ca2.default_pkey) self.convert_to_ssl(cert, ca2.default_privatekey)
def test_privkey(self): def test_privkey(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))