Beef up CertStore, add DH params.

This commit is contained in:
Aldo Cortesi 2014-03-04 14:12:58 +13:00
parent d56f7fba80
commit 7c82418e0b
2 changed files with 96 additions and 94 deletions

View File

@ -5,23 +5,27 @@ from pyasn1.error import PyAsn1Error
import OpenSSL import OpenSSL
import tcp import tcp
default_exp = 62208000 # =24 * 60 * 60 * 720 DEFAULT_EXP = 62208000 # =24 * 60 * 60 * 720
default_o = "mitmproxy" # Generated with "openssl dhparam". It's too slow to generate this on startup.
default_cn = "mitmproxy" DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS-----
MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5
zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK
1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC
-----END DH PARAMETERS-----"""
def create_ca(o=default_o, cn=default_cn, exp=default_exp): def create_ca(o, cn, exp):
key = OpenSSL.crypto.PKey() key = OpenSSL.crypto.PKey()
key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024) key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024)
ca = OpenSSL.crypto.X509() cert = OpenSSL.crypto.X509()
ca.set_serial_number(int(time.time()*10000)) cert.set_serial_number(int(time.time()*10000))
ca.set_version(2) cert.set_version(2)
ca.get_subject().CN = cn cert.get_subject().CN = cn
ca.get_subject().O = o cert.get_subject().O = o
ca.gmtime_adj_notBefore(0) cert.gmtime_adj_notBefore(0)
ca.gmtime_adj_notAfter(exp) cert.gmtime_adj_notAfter(exp)
ca.set_issuer(ca.get_subject()) cert.set_issuer(cert.get_subject())
ca.set_pubkey(key) cert.set_pubkey(key)
ca.add_extensions([ cert.add_extensions([
OpenSSL.crypto.X509Extension("basicConstraints", True, OpenSSL.crypto.X509Extension("basicConstraints", True,
"CA:TRUE"), "CA:TRUE"),
OpenSSL.crypto.X509Extension("nsCertType", True, OpenSSL.crypto.X509Extension("nsCertType", True,
@ -32,80 +36,39 @@ def create_ca(o=default_o, cn=default_cn, exp=default_exp):
OpenSSL.crypto.X509Extension("keyUsage", False, OpenSSL.crypto.X509Extension("keyUsage", False,
"keyCertSign, cRLSign"), "keyCertSign, cRLSign"),
OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash",
subject=ca), subject=cert),
]) ])
ca.sign(key, "sha1") cert.sign(key, "sha1")
return key, ca return key, cert
def dummy_ca(path, o=default_o, cn=default_cn, exp=default_exp): def dummy_cert(pkey, cacert, commonname, sans):
dirname = os.path.dirname(path)
if not os.path.exists(dirname):
os.makedirs(dirname)
if path.endswith(".pem"):
basename, _ = os.path.splitext(path)
basename = os.path.basename(basename)
else:
basename = os.path.basename(path)
key, ca = create_ca(o=o, cn=cn, exp=exp)
# Dump the CA plus private key
f = open(path, "wb")
f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key))
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Dump the certificate in PEM format
f = open(os.path.join(dirname, basename + "-cert.pem"), "wb")
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Create a .cer file with the same contents for Android
f = open(os.path.join(dirname, basename + "-cert.cer"), "wb")
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Dump the certificate in PKCS12 format for Windows devices
f = open(os.path.join(dirname, basename + "-cert.p12"), "wb")
p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca)
p12.set_privatekey(key)
f.write(p12.export())
f.close()
return True
def dummy_cert(ca, commonname, sans):
""" """
Generates and writes a certificate to fp. Generates a dummy certificate.
ca: Path to the certificate authority file, or None. pkey: CA private key
cacert: CA certificate
commonname: Common name for the generated certificate. commonname: Common name for the generated certificate.
sans: A list of Subject Alternate Names. sans: A list of Subject Alternate Names.
Returns cert path if operation succeeded, None if not. Returns cert if operation succeeded, None if not.
""" """
ss = [] ss = []
for i in sans: for i in sans:
ss.append("DNS: %s"%i) ss.append("DNS: %s"%i)
ss = ", ".join(ss) ss = ", ".join(ss)
raw = file(ca, "rb").read()
ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
cert = OpenSSL.crypto.X509() cert = OpenSSL.crypto.X509()
cert.gmtime_adj_notBefore(-3600*48) cert.gmtime_adj_notBefore(-3600*48)
cert.gmtime_adj_notAfter(60 * 60 * 24 * 30) cert.gmtime_adj_notAfter(60 * 60 * 24 * 30)
cert.set_issuer(ca.get_subject()) cert.set_issuer(cacert.get_subject())
cert.get_subject().CN = commonname cert.get_subject().CN = commonname
cert.set_serial_number(int(time.time()*10000)) cert.set_serial_number(int(time.time()*10000))
if ss: if ss:
cert.set_version(2) cert.set_version(2)
cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)]) cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", True, ss)])
cert.set_pubkey(ca.get_pubkey()) cert.set_pubkey(cacert.get_pubkey())
cert.sign(key, "sha1") cert.sign(pkey, "sha1")
return SSLCert(cert) return SSLCert(cert)
@ -113,9 +76,59 @@ class CertStore:
""" """
Implements an in-memory certificate store. Implements an in-memory certificate store.
""" """
def __init__(self, cacert): def __init__(self, pkey, cert):
self.pkey, self.cert = pkey, cert
self.certs = {} self.certs = {}
self.cacert = cacert
@classmethod
def from_store(klass, path, basename):
p = os.path.join(path, basename + "-ca.pem")
if not os.path.exists(p):
key, ca = klass.create_store(path, basename)
else:
p = os.path.join(path, basename + "-ca.pem")
raw = file(p, "rb").read()
ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
return klass(key, ca)
@classmethod
def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
if not os.path.exists(path):
os.makedirs(path)
o = o or basename
cn = cn or basename
key, ca = create_ca(o=o, cn=cn, exp=expiry)
# Dump the CA plus private key
f = open(os.path.join(path, basename + "-ca.pem"), "wb")
f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key))
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Dump the certificate in PEM format
f = open(os.path.join(path, basename + "-cert.pem"), "wb")
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Create a .cer file with the same contents for Android
f = open(os.path.join(path, basename + "-cert.cer"), "wb")
f.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
f.close()
# Dump the certificate in PKCS12 format for Windows devices
f = open(os.path.join(path, basename + "-cert.p12"), "wb")
p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca)
p12.set_privatekey(key)
f.write(p12.export())
f.close()
f = open(os.path.join(path, basename + "-dhparam.pem"), "wb")
f.write(DEFAULT_DHPARAM)
f.close()
return key, ca
def get_cert(self, commonname, sans): def get_cert(self, commonname, sans):
""" """
@ -130,7 +143,7 @@ class CertStore:
""" """
if commonname in self.certs: if commonname in self.certs:
return self.certs[commonname] return self.certs[commonname]
c = dummy_cert(self.cacert, commonname, sans) c = dummy_cert(self.pkey, self.cert, commonname, sans)
self.certs[commonname] = c self.certs[commonname] = c
return c return c

View File

@ -3,43 +3,32 @@ from netlib import certutils
import tutils import tutils
def test_dummy_ca():
with tutils.tmpdir() as d:
path = os.path.join(d, "foo/cert.cnf")
assert certutils.dummy_ca(path)
assert os.path.exists(path)
path = os.path.join(d, "foo/cert2.pem")
assert certutils.dummy_ca(path)
assert os.path.exists(path)
assert os.path.exists(os.path.join(d, "foo/cert2-cert.pem"))
assert os.path.exists(os.path.join(d, "foo/cert2-cert.p12"))
class TestCertStore: class TestCertStore:
def test_create_explicit(self): def test_create_explicit(self):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
ca = os.path.join(d, "ca") ca = certutils.CertStore.from_store(d, "test")
assert certutils.dummy_ca(ca) assert ca.get_cert("foo", [])
c = certutils.CertStore(ca)
ca2 = certutils.CertStore.from_store(d, "test")
assert ca2.get_cert("foo", [])
assert ca.cert.get_serial_number() == ca2.cert.get_serial_number()
def test_create_tmp(self): def test_create_tmp(self):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
ca = os.path.join(d, "ca") ca = certutils.CertStore.from_store(d, "test")
assert certutils.dummy_ca(ca) assert ca.get_cert("foo.com", [])
c = certutils.CertStore(ca) assert ca.get_cert("foo.com", [])
assert c.get_cert("foo.com", []) assert ca.get_cert("*.foo.com", [])
assert c.get_cert("foo.com", [])
assert c.get_cert("*.foo.com", [])
class TestDummyCert: class TestDummyCert:
def test_with_ca(self): def test_with_ca(self):
with tutils.tmpdir() as d: with tutils.tmpdir() as d:
cacert = os.path.join(d, "cacert") ca = certutils.CertStore.from_store(d, "test")
assert certutils.dummy_ca(cacert)
r = certutils.dummy_cert( r = certutils.dummy_cert(
cacert, ca.pkey,
ca.cert,
"foo.com", "foo.com",
["one.com", "two.com", "*.three.com"] ["one.com", "two.com", "*.three.com"]
) )