mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
CertStore: add support for cert chains
This commit is contained in:
parent
274688172d
commit
fdb6f5552d
@ -113,13 +113,21 @@ def dummy_cert(privkey, cacert, commonname, sans):
|
|||||||
# return current.value
|
# return current.value
|
||||||
|
|
||||||
|
|
||||||
|
class CertStoreEntry(object):
|
||||||
|
def __init__(self, cert, pkey=None, chain_file=None):
|
||||||
|
self.cert = cert
|
||||||
|
self.pkey = pkey
|
||||||
|
self.chain_file = chain_file
|
||||||
|
|
||||||
|
|
||||||
class CertStore:
|
class CertStore:
|
||||||
"""
|
"""
|
||||||
Implements an in-memory certificate store.
|
Implements an in-memory certificate store.
|
||||||
"""
|
"""
|
||||||
def __init__(self, privkey, cacert, dhparams=None):
|
def __init__(self, default_pkey, default_ca, default_chain_file, dhparams=None):
|
||||||
self.privkey, self.cacert = privkey, cacert
|
self.default_pkey = default_pkey
|
||||||
|
self.default_ca = default_ca
|
||||||
|
self.default_chain_file = default_chain_file
|
||||||
self.dhparams = dhparams
|
self.dhparams = dhparams
|
||||||
self.certs = dict()
|
self.certs = dict()
|
||||||
|
|
||||||
@ -142,21 +150,21 @@ class CertStore:
|
|||||||
return dh
|
return dh
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_store(klass, path, basename):
|
def from_store(cls, path, basename):
|
||||||
p = os.path.join(path, basename + "-ca.pem")
|
ca_path = os.path.join(path, basename + "-ca.pem")
|
||||||
if not os.path.exists(p):
|
if not os.path.exists(ca_path):
|
||||||
key, ca = klass.create_store(path, basename)
|
key, ca = cls.create_store(path, basename)
|
||||||
else:
|
else:
|
||||||
p = os.path.join(path, basename + "-ca.pem")
|
with open(ca_path, "rb") as f:
|
||||||
raw = file(p, "rb").read()
|
raw = f.read()
|
||||||
ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
|
ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
|
||||||
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
|
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
|
||||||
dhp = os.path.join(path, basename + "-dhparam.pem")
|
dh_path = os.path.join(path, basename + "-dhparam.pem")
|
||||||
dh = klass.load_dhparam(dhp)
|
dh = cls.load_dhparam(dh_path)
|
||||||
return klass(key, ca, dh)
|
return cls(key, ca, ca_path, dh)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_store(klass, path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
|
def create_store(cls, path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
|
|
||||||
@ -194,25 +202,29 @@ class CertStore:
|
|||||||
return key, ca
|
return key, ca
|
||||||
|
|
||||||
def add_cert_file(self, spec, path):
|
def add_cert_file(self, spec, path):
|
||||||
raw = file(path, "rb").read()
|
with open(path, "rb") as f:
|
||||||
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw)
|
raw = f.read()
|
||||||
|
cert = SSLCert(OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, raw))
|
||||||
try:
|
try:
|
||||||
privkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
|
pkey = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw)
|
||||||
except Exception:
|
except Exception:
|
||||||
privkey = None
|
pkey = None
|
||||||
self.add_cert(SSLCert(cert), privkey, spec)
|
self.add_cert(
|
||||||
|
CertStoreEntry(cert, pkey, path),
|
||||||
|
spec
|
||||||
|
)
|
||||||
|
|
||||||
def add_cert(self, cert, privkey, *names):
|
def add_cert(self, entry, *names):
|
||||||
"""
|
"""
|
||||||
Adds a cert to the certstore. We register the CN in the cert plus
|
Adds a cert to the certstore. We register the CN in the cert plus
|
||||||
any SANs, and also the list of names provided as an argument.
|
any SANs, and also the list of names provided as an argument.
|
||||||
"""
|
"""
|
||||||
if cert.cn:
|
if entry.cert.cn:
|
||||||
self.certs[cert.cn] = (cert, privkey)
|
self.certs[entry.cert.cn] = entry
|
||||||
for i in cert.altnames:
|
for i in entry.cert.altnames:
|
||||||
self.certs[i] = (cert, privkey)
|
self.certs[i] = entry
|
||||||
for i in names:
|
for i in names:
|
||||||
self.certs[i] = (cert, privkey)
|
self.certs[i] = entry
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def asterisk_forms(dn):
|
def asterisk_forms(dn):
|
||||||
@ -246,17 +258,17 @@ class CertStore:
|
|||||||
|
|
||||||
name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None)
|
name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None)
|
||||||
if name:
|
if name:
|
||||||
c = self.certs[name]
|
entry = self.certs[name]
|
||||||
else:
|
else:
|
||||||
c = dummy_cert(self.privkey, self.cacert, commonname, sans), None
|
entry = CertStoreEntry(cert=dummy_cert(self.default_pkey, self.default_ca, commonname, sans))
|
||||||
self.certs[(commonname, tuple(sans))] = c
|
self.certs[(commonname, tuple(sans))] = entry
|
||||||
|
|
||||||
return c[0], (c[1] or self.privkey)
|
return entry.cert, (entry.pkey or self.default_pkey), (entry.chain_file or self.default_chain_file)
|
||||||
|
|
||||||
def gen_pkey(self, cert):
|
def gen_pkey(self, cert):
|
||||||
from . import certffi
|
from . import certffi
|
||||||
certffi.set_flags(self.privkey, 1)
|
certffi.set_flags(self.default_pkey, 1)
|
||||||
return self.privkey
|
return self.default_pkey
|
||||||
|
|
||||||
|
|
||||||
class _GeneralName(univ.Choice):
|
class _GeneralName(univ.Choice):
|
||||||
|
@ -345,7 +345,7 @@ class BaseHandler(_Connection):
|
|||||||
|
|
||||||
def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None,
|
def _create_ssl_context(self, cert, key, method=SSLv23_METHOD, options=None,
|
||||||
handle_sni=None, request_client_cert=None, cipher_list=None,
|
handle_sni=None, request_client_cert=None, cipher_list=None,
|
||||||
dhparams=None, ca_file=None):
|
dhparams=None, chain_file=None):
|
||||||
"""
|
"""
|
||||||
cert: A certutils.SSLCert object.
|
cert: A certutils.SSLCert object.
|
||||||
|
|
||||||
@ -377,8 +377,8 @@ class BaseHandler(_Connection):
|
|||||||
ctx = SSL.Context(method)
|
ctx = SSL.Context(method)
|
||||||
if not options is None:
|
if not options is None:
|
||||||
ctx.set_options(options)
|
ctx.set_options(options)
|
||||||
if ca_file:
|
if chain_file:
|
||||||
ctx.load_verify_locations(ca_file)
|
ctx.load_verify_locations(chain_file)
|
||||||
if cipher_list:
|
if cipher_list:
|
||||||
try:
|
try:
|
||||||
ctx.set_cipher_list(cipher_list)
|
ctx.set_cipher_list(cipher_list)
|
||||||
|
@ -42,7 +42,7 @@ class TestCertStore:
|
|||||||
ca2 = certutils.CertStore.from_store(d, "test")
|
ca2 = certutils.CertStore.from_store(d, "test")
|
||||||
assert ca2.get_cert("foo", [])
|
assert ca2.get_cert("foo", [])
|
||||||
|
|
||||||
assert ca.cacert.get_serial_number() == ca2.cacert.get_serial_number()
|
assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
|
||||||
|
|
||||||
def test_create_tmp(self):
|
def test_create_tmp(self):
|
||||||
with tutils.tmpdir() as d:
|
with tutils.tmpdir() as d:
|
||||||
@ -52,7 +52,7 @@ class TestCertStore:
|
|||||||
assert ca.get_cert("*.foo.com", [])
|
assert ca.get_cert("*.foo.com", [])
|
||||||
|
|
||||||
r = ca.get_cert("*.foo.com", [])
|
r = ca.get_cert("*.foo.com", [])
|
||||||
assert r[1] == ca.privkey
|
assert r[1] == ca.default_pkey
|
||||||
|
|
||||||
def test_add_cert(self):
|
def test_add_cert(self):
|
||||||
with tutils.tmpdir() as d:
|
with tutils.tmpdir() as d:
|
||||||
@ -71,14 +71,14 @@ class TestCertStore:
|
|||||||
with tutils.tmpdir() as d:
|
with tutils.tmpdir() as d:
|
||||||
ca = certutils.CertStore.from_store(d, "test")
|
ca = certutils.CertStore.from_store(d, "test")
|
||||||
_ = ca.get_cert("foo.com", ["*.bar.com"])
|
_ = ca.get_cert("foo.com", ["*.bar.com"])
|
||||||
cert, key = ca.get_cert("foo.bar.com", ["*.baz.com"])
|
cert, key, chain_file = ca.get_cert("foo.bar.com", ["*.baz.com"])
|
||||||
assert "*.baz.com" in cert.altnames
|
assert "*.baz.com" in cert.altnames
|
||||||
|
|
||||||
def test_overrides(self):
|
def test_overrides(self):
|
||||||
with tutils.tmpdir() as d:
|
with tutils.tmpdir() as d:
|
||||||
ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
|
ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
|
||||||
ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
|
ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
|
||||||
assert not ca1.cacert.get_serial_number() == ca2.cacert.get_serial_number()
|
assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
|
||||||
|
|
||||||
dc = ca2.get_cert("foo.com", [])
|
dc = ca2.get_cert("foo.com", [])
|
||||||
dcp = os.path.join(d, "dc")
|
dcp = os.path.join(d, "dc")
|
||||||
@ -98,7 +98,7 @@ class TestCertStore:
|
|||||||
cert = ca1.get_cert("foo.com", [])
|
cert = ca1.get_cert("foo.com", [])
|
||||||
assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1
|
assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1
|
||||||
finally:
|
finally:
|
||||||
certffi.set_flags(ca2.privkey, 0)
|
certffi.set_flags(ca2.default_pkey, 0)
|
||||||
|
|
||||||
|
|
||||||
class TestDummyCert:
|
class TestDummyCert:
|
||||||
@ -106,8 +106,8 @@ class TestDummyCert:
|
|||||||
with tutils.tmpdir() as d:
|
with tutils.tmpdir() as d:
|
||||||
ca = certutils.CertStore.from_store(d, "test")
|
ca = certutils.CertStore.from_store(d, "test")
|
||||||
r = certutils.dummy_cert(
|
r = certutils.dummy_cert(
|
||||||
ca.privkey,
|
ca.default_pkey,
|
||||||
ca.cacert,
|
ca.default_ca,
|
||||||
"foo.com",
|
"foo.com",
|
||||||
["one.com", "two.com", "*.three.com"]
|
["one.com", "two.com", "*.three.com"]
|
||||||
)
|
)
|
||||||
|
@ -393,7 +393,7 @@ class TestPrivkeyGen(test.ServerTestBase):
|
|||||||
with tutils.tmpdir() as d:
|
with tutils.tmpdir() as d:
|
||||||
ca1 = certutils.CertStore.from_store(d, "test2")
|
ca1 = certutils.CertStore.from_store(d, "test2")
|
||||||
ca2 = certutils.CertStore.from_store(d, "test3")
|
ca2 = certutils.CertStore.from_store(d, "test3")
|
||||||
cert, _ = ca1.get_cert("foo.com", [])
|
cert, _, _ = ca1.get_cert("foo.com", [])
|
||||||
key = ca2.gen_pkey(cert)
|
key = ca2.gen_pkey(cert)
|
||||||
self.convert_to_ssl(cert, key)
|
self.convert_to_ssl(cert, key)
|
||||||
|
|
||||||
@ -409,9 +409,9 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase):
|
|||||||
with tutils.tmpdir() as d:
|
with tutils.tmpdir() as d:
|
||||||
ca1 = certutils.CertStore.from_store(d, "test2")
|
ca1 = certutils.CertStore.from_store(d, "test2")
|
||||||
ca2 = certutils.CertStore.from_store(d, "test3")
|
ca2 = certutils.CertStore.from_store(d, "test3")
|
||||||
cert, _ = ca1.get_cert("foo.com", [])
|
cert, _, _ = ca1.get_cert("foo.com", [])
|
||||||
certffi.set_flags(ca2.privkey, 0)
|
certffi.set_flags(ca2.default_pkey, 0)
|
||||||
self.convert_to_ssl(cert, ca2.privkey)
|
self.convert_to_ssl(cert, ca2.default_pkey)
|
||||||
|
|
||||||
def test_privkey(self):
|
def test_privkey(self):
|
||||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
|
Loading…
Reference in New Issue
Block a user