diff --git a/libmproxy/certutils.py b/libmproxy/certutils.py index 43d14091e..c1e5d93e5 100644 --- a/libmproxy/certutils.py +++ b/libmproxy/certutils.py @@ -182,15 +182,9 @@ def dummy_cert(certdir, ca, commonname, sans): return certpath -def get_remote_cn(host, port): - addr = socket.gethostbyname(host) - s = ssl.get_server_certificate((addr, port)) - return parse_text_cert(s) - - class GeneralName(univ.Choice): # We are only interested in dNSNames. We use a default handler to ignore - # other types. + # other types. componentType = namedtype.NamedTypes( namedtype.NamedType('dNSName', char.IA5String().subtype( implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) @@ -198,27 +192,43 @@ class GeneralName(univ.Choice): ), ) + class GeneralNames(univ.SequenceOf): componentType = GeneralName() sizeSpec = univ.SequenceOf.sizeSpec + constraint.ValueSizeConstraint(1, 1024) +class SSLCert: + def __init__(self, pemtxt): + """ + Returns a (common name, [subject alternative names]) tuple. + """ + self.cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pemtxt) + + @property + def cn(self): + cn = None + for i in self.cert.get_subject().get_components(): + if i[0] == "CN": + cn = i[1] + return cn + + @property + def altnames(self): + altnames = [] + for i in range(self.cert.get_extension_count()): + ext = self.cert.get_extension(i) + if ext.get_short_name() == "subjectAltName": + dec = decode(ext.get_data(), asn1Spec=GeneralNames()) + for i in dec[0]: + altnames.append(i[0]) + return altnames + + + +def get_remote_cert(host, port): + addr = socket.gethostbyname(host) + s = ssl.get_server_certificate((addr, port)) + return SSLCert(s) -def parse_text_cert(txt): - """ - Returns a (common name, [subject alternative names]) tuple. - """ - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) - cn = None - for i in cert.get_subject().get_components(): - if i[0] == "CN": - cn = i[1] - altnames = [] - for i in range(cert.get_extension_count()): - ext = cert.get_extension(i) - if ext.get_short_name() == "subjectAltName": - dec = decode(ext.get_data(), asn1Spec=GeneralNames()) - for i in dec[0]: - altnames.append(i[0]) - return cn, altnames diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index a6ba790fd..31308e6f5 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -350,7 +350,9 @@ class ProxyHandler(SocketServer.StreamRequestHandler): else: sans = [] if self.config.upstream_cert: - host, sans = certutils.get_remote_cn(host, port) + cert = certutils.get_remote_cert(host, port) + sans = cert.altnames + host = cert.cn ret = certutils.dummy_cert(self.config.certdir, self.config.cacert, host, sans) time.sleep(self.config.cert_wait_time) if not ret: diff --git a/test/test_certutils.py b/test/test_certutils.py index 15e81f746..5ef5919e8 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -51,15 +51,13 @@ class udummy_cert(libpry.AutoTree): class uparse_text_cert(libpry.AutoTree): def test_simple(self): - c = file("data/text_cert", "r").read() - cn, san = certutils.parse_text_cert(c) - assert cn == "google.com" - assert len(san) == 436 + c = certutils.SSLCert(file("data/text_cert", "r").read()) + assert c.cn == "google.com" + assert len(c.altnames) == 436 - c = file("data/text_cert_2", "r").read() - cn, san = certutils.parse_text_cert(c) - assert cn == "www.inode.co.nz" - assert len(san) == 2 + c = certutils.SSLCert(file("data/text_cert_2", "r").read()) + assert c.cn == "www.inode.co.nz" + assert len(c.altnames) == 2 tests = [