temporarily replace DNTree with a simpler cert lookup mechanism, fix mitmproxy/mitmproxy#295

This commit is contained in:
Maximilian Hils 2014-07-18 22:55:25 +02:00
parent 55c2133b69
commit a7837846a2
2 changed files with 82 additions and 75 deletions

View File

@ -1,4 +1,5 @@
import os, ssl, time, datetime import os, ssl, time, datetime
import itertools
from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error from pyasn1.error import PyAsn1Error
@ -73,42 +74,44 @@ def dummy_cert(privkey, cacert, commonname, sans):
return SSLCert(cert) return SSLCert(cert)
class _Node(UserDict.UserDict): # DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict.
def __init__(self): #
UserDict.UserDict.__init__(self) # class _Node(UserDict.UserDict):
self.value = None # def __init__(self):
# UserDict.UserDict.__init__(self)
# self.value = None
class DNTree: #
""" #
Domain store that knows about wildcards. DNS wildcards are very # class DNTree:
restricted - the only valid variety is an asterisk on the left-most # """
domain component, i.e.: # Domain store that knows about wildcards. DNS wildcards are very
# restricted - the only valid variety is an asterisk on the left-most
*.foo.com # domain component, i.e.:
""" #
def __init__(self): # *.foo.com
self.d = _Node() # """
# def __init__(self):
def add(self, dn, cert): # self.d = _Node()
parts = dn.split(".") #
parts.reverse() # def add(self, dn, cert):
current = self.d # parts = dn.split(".")
for i in parts: # parts.reverse()
current = current.setdefault(i, _Node()) # current = self.d
current.value = cert # for i in parts:
# current = current.setdefault(i, _Node())
def get(self, dn): # current.value = cert
parts = dn.split(".") #
current = self.d # def get(self, dn):
for i in reversed(parts): # parts = dn.split(".")
if i in current: # current = self.d
current = current[i] # for i in reversed(parts):
elif "*" in current: # if i in current:
return current["*"].value # current = current[i]
else: # elif "*" in current:
return None # return current["*"].value
return current.value # else:
# return None
# return current.value
@ -119,7 +122,7 @@ class CertStore:
def __init__(self, privkey, cacert, dhparams=None): def __init__(self, privkey, cacert, dhparams=None):
self.privkey, self.cacert = privkey, cacert self.privkey, self.cacert = privkey, cacert
self.dhparams = dhparams self.dhparams = dhparams
self.certs = DNTree() self.certs = dict()
@classmethod @classmethod
def load_dhparam(klass, path): def load_dhparam(klass, path):
@ -206,11 +209,11 @@ class CertStore:
any SANs, and also the list of names provided as an argument. any SANs, and also the list of names provided as an argument.
""" """
if cert.cn: if cert.cn:
self.certs.add(cert.cn, (cert, privkey)) self.certs[cert.cn] = (cert, privkey)
for i in cert.altnames: for i in cert.altnames:
self.certs.add(i, (cert, privkey)) self.certs[i] = (cert, privkey)
for i in names: for i in names:
self.certs.add(i, (cert, privkey)) self.certs[i] = (cert, privkey)
def get_cert(self, commonname, sans): def get_cert(self, commonname, sans):
""" """
@ -223,12 +226,16 @@ class CertStore:
Return None if the certificate could not be found or generated. Return None if the certificate could not be found or generated.
""" """
c = self.certs.get(commonname)
if not c: potential_keys = [commonname] + sans + [(commonname, tuple(sans))]
c = dummy_cert(self.privkey, self.cacert, commonname, sans) name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None)
self.add_cert(c, None) if name:
c = (c, None) c = self.certs[name]
return (c[0], c[1] or self.privkey) else:
c = dummy_cert(self.privkey, self.cacert, commonname, sans), None
self.certs[(commonname, tuple(sans))] = c
return c[0], (c[1] or self.privkey)
def gen_pkey(self, cert): def gen_pkey(self, cert):
import certffi import certffi

View File

@ -3,34 +3,34 @@ from netlib import certutils, certffi
import OpenSSL import OpenSSL
import tutils import tutils
class TestDNTree: # class TestDNTree:
def test_simple(self): # def test_simple(self):
d = certutils.DNTree() # d = certutils.DNTree()
d.add("foo.com", "foo") # d.add("foo.com", "foo")
d.add("bar.com", "bar") # d.add("bar.com", "bar")
assert d.get("foo.com") == "foo" # assert d.get("foo.com") == "foo"
assert d.get("bar.com") == "bar" # assert d.get("bar.com") == "bar"
assert not d.get("oink.com") # assert not d.get("oink.com")
assert not d.get("oink") # assert not d.get("oink")
assert not d.get("") # assert not d.get("")
assert not d.get("oink.oink") # assert not d.get("oink.oink")
#
d.add("*.match.org", "match") # d.add("*.match.org", "match")
assert not d.get("match.org") # assert not d.get("match.org")
assert d.get("foo.match.org") == "match" # assert d.get("foo.match.org") == "match"
assert d.get("foo.foo.match.org") == "match" # assert d.get("foo.foo.match.org") == "match"
#
def test_wildcard(self): # def test_wildcard(self):
d = certutils.DNTree() # d = certutils.DNTree()
d.add("foo.com", "foo") # d.add("foo.com", "foo")
assert not d.get("*.foo.com") # assert not d.get("*.foo.com")
d.add("*.foo.com", "wild") # d.add("*.foo.com", "wild")
#
d = certutils.DNTree() # d = certutils.DNTree()
d.add("*", "foo") # d.add("*", "foo")
assert d.get("foo.com") == "foo" # assert d.get("foo.com") == "foo"
assert d.get("*.foo.com") == "foo" # assert d.get("*.foo.com") == "foo"
assert d.get("com") == "foo" # assert d.get("com") == "foo"
class TestCertStore: class TestCertStore:
@ -63,7 +63,7 @@ class TestCertStore:
ca = certutils.CertStore.from_store(d, "test") ca = certutils.CertStore.from_store(d, "test")
c1 = ca.get_cert("foo.com", ["*.bar.com"]) c1 = ca.get_cert("foo.com", ["*.bar.com"])
c2 = ca.get_cert("foo.bar.com", []) c2 = ca.get_cert("foo.bar.com", [])
assert c1 == c2 # assert c1 == c2
c3 = ca.get_cert("bar.com", []) c3 = ca.get_cert("bar.com", [])
assert not c1 == c3 assert not c1 == c3