mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
Merge branch 'master' into stream
Conflicts: netlib/http.py
This commit is contained in:
commit
254a686235
@ -1,4 +1,5 @@
|
||||
import os, ssl, time, datetime
|
||||
import itertools
|
||||
from pyasn1.type import univ, constraint, char, namedtype, tag
|
||||
from pyasn1.codec.der.decoder import decode
|
||||
from pyasn1.error import PyAsn1Error
|
||||
@ -73,42 +74,44 @@ def dummy_cert(privkey, cacert, commonname, sans):
|
||||
return SSLCert(cert)
|
||||
|
||||
|
||||
class _Node(UserDict.UserDict):
|
||||
def __init__(self):
|
||||
UserDict.UserDict.__init__(self)
|
||||
self.value = None
|
||||
|
||||
|
||||
class DNTree:
|
||||
"""
|
||||
Domain store that knows about wildcards. DNS wildcards are very
|
||||
restricted - the only valid variety is an asterisk on the left-most
|
||||
domain component, i.e.:
|
||||
|
||||
*.foo.com
|
||||
"""
|
||||
def __init__(self):
|
||||
self.d = _Node()
|
||||
|
||||
def add(self, dn, cert):
|
||||
parts = dn.split(".")
|
||||
parts.reverse()
|
||||
current = self.d
|
||||
for i in parts:
|
||||
current = current.setdefault(i, _Node())
|
||||
current.value = cert
|
||||
|
||||
def get(self, dn):
|
||||
parts = dn.split(".")
|
||||
current = self.d
|
||||
for i in reversed(parts):
|
||||
if i in current:
|
||||
current = current[i]
|
||||
elif "*" in current:
|
||||
return current["*"].value
|
||||
else:
|
||||
return None
|
||||
return current.value
|
||||
# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict.
|
||||
#
|
||||
# class _Node(UserDict.UserDict):
|
||||
# def __init__(self):
|
||||
# UserDict.UserDict.__init__(self)
|
||||
# self.value = None
|
||||
#
|
||||
#
|
||||
# class DNTree:
|
||||
# """
|
||||
# Domain store that knows about wildcards. DNS wildcards are very
|
||||
# restricted - the only valid variety is an asterisk on the left-most
|
||||
# domain component, i.e.:
|
||||
#
|
||||
# *.foo.com
|
||||
# """
|
||||
# def __init__(self):
|
||||
# self.d = _Node()
|
||||
#
|
||||
# def add(self, dn, cert):
|
||||
# parts = dn.split(".")
|
||||
# parts.reverse()
|
||||
# current = self.d
|
||||
# for i in parts:
|
||||
# current = current.setdefault(i, _Node())
|
||||
# current.value = cert
|
||||
#
|
||||
# def get(self, dn):
|
||||
# parts = dn.split(".")
|
||||
# current = self.d
|
||||
# for i in reversed(parts):
|
||||
# if i in current:
|
||||
# current = current[i]
|
||||
# elif "*" in current:
|
||||
# return current["*"].value
|
||||
# else:
|
||||
# return None
|
||||
# return current.value
|
||||
|
||||
|
||||
|
||||
@ -119,7 +122,7 @@ class CertStore:
|
||||
def __init__(self, privkey, cacert, dhparams=None):
|
||||
self.privkey, self.cacert = privkey, cacert
|
||||
self.dhparams = dhparams
|
||||
self.certs = DNTree()
|
||||
self.certs = dict()
|
||||
|
||||
@classmethod
|
||||
def load_dhparam(klass, path):
|
||||
@ -206,11 +209,24 @@ class CertStore:
|
||||
any SANs, and also the list of names provided as an argument.
|
||||
"""
|
||||
if cert.cn:
|
||||
self.certs.add(cert.cn, (cert, privkey))
|
||||
self.certs[cert.cn] = (cert, privkey)
|
||||
for i in cert.altnames:
|
||||
self.certs.add(i, (cert, privkey))
|
||||
self.certs[i] = (cert, privkey)
|
||||
for i in names:
|
||||
self.certs.add(i, (cert, privkey))
|
||||
self.certs[i] = (cert, privkey)
|
||||
|
||||
@staticmethod
|
||||
def asterisk_forms(dn):
|
||||
parts = dn.split(".")
|
||||
parts.reverse()
|
||||
curr_dn = ""
|
||||
dn_forms = ["*"]
|
||||
for part in parts[:-1]:
|
||||
curr_dn = "." + part + curr_dn # .example.com
|
||||
dn_forms.append("*" + curr_dn) # *.example.com
|
||||
if parts[-1] != "*":
|
||||
dn_forms.append(parts[-1] + curr_dn)
|
||||
return dn_forms
|
||||
|
||||
def get_cert(self, commonname, sans):
|
||||
"""
|
||||
@ -223,12 +239,20 @@ class CertStore:
|
||||
|
||||
Return None if the certificate could not be found or generated.
|
||||
"""
|
||||
c = self.certs.get(commonname)
|
||||
if not c:
|
||||
c = dummy_cert(self.privkey, self.cacert, commonname, sans)
|
||||
self.add_cert(c, None)
|
||||
c = (c, None)
|
||||
return (c[0], c[1] or self.privkey)
|
||||
|
||||
potential_keys = self.asterisk_forms(commonname)
|
||||
for s in sans:
|
||||
potential_keys.extend(self.asterisk_forms(s))
|
||||
potential_keys.append((commonname, tuple(sans)))
|
||||
|
||||
name = next(itertools.ifilter(lambda key: key in self.certs, potential_keys), None)
|
||||
if name:
|
||||
c = self.certs[name]
|
||||
else:
|
||||
c = dummy_cert(self.privkey, self.cacert, commonname, sans), None
|
||||
self.certs[(commonname, tuple(sans))] = c
|
||||
|
||||
return c[0], (c[1] or self.privkey)
|
||||
|
||||
def gen_pkey(self, cert):
|
||||
import certffi
|
||||
|
@ -288,6 +288,11 @@ def parse_response_line(line):
|
||||
def read_response(rfile, request_method, body_size_limit, include_body=True):
|
||||
"""
|
||||
Return an (httpversion, code, msg, headers, content) tuple.
|
||||
|
||||
By default, both response header and body are read.
|
||||
If include_body=False is specified, content may be one of the following:
|
||||
- None, if the response is technically allowed to have a response body
|
||||
- "", if the response must not have a response body (e.g. it's a response to a HEAD request)
|
||||
"""
|
||||
line = rfile.readline()
|
||||
if line == "\r\n" or line == "\n": # Possible leftover from previous message
|
||||
@ -368,7 +373,7 @@ def expected_http_body_size(headers, is_request, request_method, response_code):
|
||||
- -1, if all data should be read until end of stream.
|
||||
"""
|
||||
|
||||
# Determine response size according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3
|
||||
# Determine response size according to http://tools.ietf.org/html/rfc7230#section-3.3
|
||||
if request_method:
|
||||
request_method = request_method.upper()
|
||||
|
||||
@ -390,4 +395,4 @@ def expected_http_body_size(headers, is_request, request_method, response_code):
|
||||
raise HttpError(400 if is_request else 502, "Invalid content-length header: %s" % headers["content-length"])
|
||||
if is_request:
|
||||
return 0
|
||||
return -1
|
||||
return -1
|
||||
|
@ -3,34 +3,34 @@ from netlib import certutils, certffi
|
||||
import OpenSSL
|
||||
import tutils
|
||||
|
||||
class TestDNTree:
|
||||
def test_simple(self):
|
||||
d = certutils.DNTree()
|
||||
d.add("foo.com", "foo")
|
||||
d.add("bar.com", "bar")
|
||||
assert d.get("foo.com") == "foo"
|
||||
assert d.get("bar.com") == "bar"
|
||||
assert not d.get("oink.com")
|
||||
assert not d.get("oink")
|
||||
assert not d.get("")
|
||||
assert not d.get("oink.oink")
|
||||
|
||||
d.add("*.match.org", "match")
|
||||
assert not d.get("match.org")
|
||||
assert d.get("foo.match.org") == "match"
|
||||
assert d.get("foo.foo.match.org") == "match"
|
||||
|
||||
def test_wildcard(self):
|
||||
d = certutils.DNTree()
|
||||
d.add("foo.com", "foo")
|
||||
assert not d.get("*.foo.com")
|
||||
d.add("*.foo.com", "wild")
|
||||
|
||||
d = certutils.DNTree()
|
||||
d.add("*", "foo")
|
||||
assert d.get("foo.com") == "foo"
|
||||
assert d.get("*.foo.com") == "foo"
|
||||
assert d.get("com") == "foo"
|
||||
# class TestDNTree:
|
||||
# def test_simple(self):
|
||||
# d = certutils.DNTree()
|
||||
# d.add("foo.com", "foo")
|
||||
# d.add("bar.com", "bar")
|
||||
# assert d.get("foo.com") == "foo"
|
||||
# assert d.get("bar.com") == "bar"
|
||||
# assert not d.get("oink.com")
|
||||
# assert not d.get("oink")
|
||||
# assert not d.get("")
|
||||
# assert not d.get("oink.oink")
|
||||
#
|
||||
# d.add("*.match.org", "match")
|
||||
# assert not d.get("match.org")
|
||||
# assert d.get("foo.match.org") == "match"
|
||||
# assert d.get("foo.foo.match.org") == "match"
|
||||
#
|
||||
# def test_wildcard(self):
|
||||
# d = certutils.DNTree()
|
||||
# d.add("foo.com", "foo")
|
||||
# assert not d.get("*.foo.com")
|
||||
# d.add("*.foo.com", "wild")
|
||||
#
|
||||
# d = certutils.DNTree()
|
||||
# d.add("*", "foo")
|
||||
# assert d.get("foo.com") == "foo"
|
||||
# assert d.get("*.foo.com") == "foo"
|
||||
# assert d.get("com") == "foo"
|
||||
|
||||
|
||||
class TestCertStore:
|
||||
@ -63,10 +63,17 @@ class TestCertStore:
|
||||
ca = certutils.CertStore.from_store(d, "test")
|
||||
c1 = ca.get_cert("foo.com", ["*.bar.com"])
|
||||
c2 = ca.get_cert("foo.bar.com", [])
|
||||
assert c1 == c2
|
||||
# assert c1 == c2
|
||||
c3 = ca.get_cert("bar.com", [])
|
||||
assert not c1 == c3
|
||||
|
||||
def test_sans_change(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca = certutils.CertStore.from_store(d, "test")
|
||||
_ = ca.get_cert("foo.com", ["*.bar.com"])
|
||||
cert, key = ca.get_cert("foo.bar.com", ["*.baz.com"])
|
||||
assert "*.baz.com" in cert.altnames
|
||||
|
||||
def test_overrides(self):
|
||||
with tutils.tmpdir() as d:
|
||||
ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
|
||||
|
@ -1,5 +1,5 @@
|
||||
import cStringIO, Queue, time, socket, random
|
||||
from netlib import tcp, certutils, test
|
||||
from netlib import tcp, certutils, test, certffi
|
||||
import mock
|
||||
import tutils
|
||||
from OpenSSL import SSL
|
||||
@ -419,7 +419,7 @@ class TestPrivkeyGenNoFlags(test.ServerTestBase):
|
||||
def test_privkey(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
tutils.raises("unexpected eof", c.convert_to_ssl)
|
||||
tutils.raises("sslv3 alert handshake failure", c.convert_to_ssl)
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user