From 9ef84ccc1cdd0d8da890ba012812c760e31f2fab Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 9 Oct 2014 00:15:39 +0200 Subject: [PATCH] clean up code --- netlib/certutils.py | 71 +++++++++++++++++++++--------------------- test/test_certutils.py | 6 ++-- test/test_tcp.py | 4 +-- 3 files changed, 41 insertions(+), 40 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index c9e6df26a..af6177d8c 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -114,9 +114,9 @@ def dummy_cert(privkey, cacert, commonname, sans): class CertStoreEntry(object): - def __init__(self, cert, pkey=None, chain_file=None): + def __init__(self, cert, privatekey, chain_file): self.cert = cert - self.pkey = pkey + self.privatekey = privatekey self.chain_file = chain_file @@ -124,15 +124,15 @@ class CertStore: """ Implements an in-memory certificate store. """ - def __init__(self, default_pkey, default_ca, default_chain_file, dhparams=None): - self.default_pkey = default_pkey + def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None): + self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file self.dhparams = dhparams self.certs = dict() - @classmethod - def load_dhparam(klass, path): + @staticmethod + def load_dhparam(path): # netlib<=0.10 doesn't generate a dhparam file. # Create it now if neccessary. @@ -163,8 +163,8 @@ class CertStore: dh = cls.load_dhparam(dh_path) return cls(key, ca, ca_path, dh) - @classmethod - def create_store(cls, path, basename, o=None, cn=None, expiry=DEFAULT_EXP): + @staticmethod + def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP): if not os.path.exists(path): os.makedirs(path) @@ -173,32 +173,28 @@ class CertStore: key, ca = create_ca(o=o, cn=cn, exp=expiry) # Dump the CA plus private key - f = open(os.path.join(path, basename + "-ca.pem"), "wb") - f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key)) + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PEM format - f = open(os.path.join(path, basename + "-ca-cert.pem"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Create a .cer file with the same contents for Android - f = open(os.path.join(path, basename + "-ca-cert.cer"), "wb") - f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) - f.close() + with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: + f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PKCS12 format for Windows devices - f = open(os.path.join(path, basename + "-ca-cert.p12"), "wb") - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - f.close() + with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: + p12 = OpenSSL.crypto.PKCS12() + p12.set_certificate(ca) + p12.set_privatekey(key) + f.write(p12.export()) + + with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: + f.write(DEFAULT_DHPARAM) - f = open(os.path.join(path, basename + "-dhparam.pem"), "wb") - f.write(DEFAULT_DHPARAM) - f.close() return key, ca def add_cert_file(self, spec, path): @@ -206,11 +202,11 @@ class CertStore: raw = f.read() cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)) try: - pkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) + privatekey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw) except Exception: - pkey = None + privatekey = self.default_privatekey self.add_cert( - CertStoreEntry(cert, pkey, path), + CertStoreEntry(cert, privatekey, path), spec ) @@ -241,7 +237,7 @@ class CertStore: def get_cert(self, commonname, sans): """ - Returns an (cert, privkey) tuple. + Returns an (cert, privkey, cert_chain) tuple. commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. @@ -260,15 +256,20 @@ class CertStore: if name: entry = self.certs[name] else: - entry = CertStoreEntry(cert=dummy_cert(self.default_pkey, self.default_ca, commonname, sans)) + entry = CertStoreEntry( + cert=dummy_cert(self.default_privatekey, self.default_ca, commonname, sans), + privatekey=self.default_privatekey, + chain_file=self.default_chain_file + ) self.certs[(commonname, tuple(sans))] = entry - return entry.cert, (entry.pkey or self.default_pkey), (entry.chain_file or self.default_chain_file) + return entry.cert, entry.privatekey, entry.chain_file def gen_pkey(self, cert): + # FIXME: We should do something with cert here? from . import certffi - certffi.set_flags(self.default_pkey, 1) - return self.default_pkey + certffi.set_flags(self.default_privatekey, 1) + return self.default_privatekey class _GeneralName(univ.Choice): diff --git a/test/test_certutils.py b/test/test_certutils.py index f68751ece..59c9dcd57 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -52,7 +52,7 @@ class TestCertStore: assert ca.get_cert("*.foo.com", []) r = ca.get_cert("*.foo.com", []) - assert r[1] == ca.default_pkey + assert r[1] == ca.default_privatekey def test_add_cert(self): with tutils.tmpdir() as d: @@ -98,7 +98,7 @@ class TestCertStore: cert = ca1.get_cert("foo.com", []) assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1 finally: - certffi.set_flags(ca2.default_pkey, 0) + certffi.set_flags(ca2.default_privatekey, 0) class TestDummyCert: @@ -106,7 +106,7 @@ class TestDummyCert: with tutils.tmpdir() as d: ca = certutils.CertStore.from_store(d, "test") r = certutils.dummy_cert( - ca.default_pkey, + ca.default_privatekey, ca.default_ca, "foo.com", ["one.com", "two.com", "*.three.com"] diff --git a/test/test_tcp.py b/test/test_tcp.py index 0eadac47f..bf3d46bf2 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -410,8 +410,8 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase): ca1 = certutils.CertStore.from_store(d, "test2") ca2 = certutils.CertStore.from_store(d, "test3") cert, _, _ = ca1.get_cert("foo.com", []) - certffi.set_flags(ca2.default_pkey, 0) - self.convert_to_ssl(cert, ca2.default_pkey) + certffi.set_flags(ca2.default_privatekey, 0) + self.convert_to_ssl(cert, ca2.default_privatekey) def test_privkey(self): c = tcp.TCPClient(("127.0.0.1", self.port))