mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
fix #2563
This commit is contained in:
parent
1f3fec2a3e
commit
4a6d838ecc
@ -4,6 +4,7 @@ import time
|
|||||||
import datetime
|
import datetime
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import sys
|
import sys
|
||||||
|
import typing
|
||||||
|
|
||||||
from pyasn1.type import univ, constraint, char, namedtype, tag
|
from pyasn1.type import univ, constraint, char, namedtype, tag
|
||||||
from pyasn1.codec.der.decoder import decode
|
from pyasn1.codec.der.decoder import decode
|
||||||
@ -122,6 +123,11 @@ class CertStoreEntry:
|
|||||||
self.chain_file = chain_file
|
self.chain_file = chain_file
|
||||||
|
|
||||||
|
|
||||||
|
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
|
||||||
|
TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans)
|
||||||
|
TCertId = typing.Union[TCustomCertId, TGeneratedCertId]
|
||||||
|
|
||||||
|
|
||||||
class CertStore:
|
class CertStore:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -139,7 +145,7 @@ class CertStore:
|
|||||||
self.default_ca = default_ca
|
self.default_ca = default_ca
|
||||||
self.default_chain_file = default_chain_file
|
self.default_chain_file = default_chain_file
|
||||||
self.dhparams = dhparams
|
self.dhparams = dhparams
|
||||||
self.certs = dict()
|
self.certs = {} # type: typing.Dict[TCertId, CertStoreEntry]
|
||||||
self.expire_queue = []
|
self.expire_queue = []
|
||||||
|
|
||||||
def expire(self, entry):
|
def expire(self, entry):
|
||||||
@ -240,7 +246,7 @@ class CertStore:
|
|||||||
|
|
||||||
return key, ca
|
return key, ca
|
||||||
|
|
||||||
def add_cert_file(self, spec, path):
|
def add_cert_file(self, spec: str, path: str) -> None:
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
raw = f.read()
|
raw = f.read()
|
||||||
cert = SSLCert(
|
cert = SSLCert(
|
||||||
@ -255,10 +261,10 @@ class CertStore:
|
|||||||
privatekey = self.default_privatekey
|
privatekey = self.default_privatekey
|
||||||
self.add_cert(
|
self.add_cert(
|
||||||
CertStoreEntry(cert, privatekey, path),
|
CertStoreEntry(cert, privatekey, path),
|
||||||
spec
|
spec.encode("idna")
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_cert(self, entry, *names):
|
def add_cert(self, entry: CertStoreEntry, *names: bytes):
|
||||||
"""
|
"""
|
||||||
Adds a cert to the certstore. We register the CN in the cert plus
|
Adds a cert to the certstore. We register the CN in the cert plus
|
||||||
any SANs, and also the list of names provided as an argument.
|
any SANs, and also the list of names provided as an argument.
|
||||||
@ -271,21 +277,18 @@ class CertStore:
|
|||||||
self.certs[i] = entry
|
self.certs[i] = entry
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def asterisk_forms(dn):
|
def asterisk_forms(dn: bytes) -> typing.List[bytes]:
|
||||||
if dn is None:
|
"""
|
||||||
return []
|
Return all asterisk forms for a domain. For example, for www.example.com this will return
|
||||||
|
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
|
||||||
|
"""
|
||||||
parts = dn.split(b".")
|
parts = dn.split(b".")
|
||||||
parts.reverse()
|
ret = [dn]
|
||||||
curr_dn = b""
|
for i in range(1, len(parts)):
|
||||||
dn_forms = [b"*"]
|
ret.append(b"*." + b".".join(parts[i:]))
|
||||||
for part in parts[:-1]:
|
return ret
|
||||||
curr_dn = b"." + part + curr_dn # .example.com
|
|
||||||
dn_forms.append(b"*" + curr_dn) # *.example.com
|
|
||||||
if parts[-1] != b"*":
|
|
||||||
dn_forms.append(parts[-1] + curr_dn)
|
|
||||||
return dn_forms
|
|
||||||
|
|
||||||
def get_cert(self, commonname, sans):
|
def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes]):
|
||||||
"""
|
"""
|
||||||
Returns an (cert, privkey, cert_chain) tuple.
|
Returns an (cert, privkey, cert_chain) tuple.
|
||||||
|
|
||||||
@ -295,9 +298,12 @@ class CertStore:
|
|||||||
sans: A list of Subject Alternate Names.
|
sans: A list of Subject Alternate Names.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
potential_keys = self.asterisk_forms(commonname)
|
potential_keys = [] # type: typing.List[TCertId]
|
||||||
|
if commonname:
|
||||||
|
potential_keys.extend(self.asterisk_forms(commonname))
|
||||||
for s in sans:
|
for s in sans:
|
||||||
potential_keys.extend(self.asterisk_forms(s))
|
potential_keys.extend(self.asterisk_forms(s))
|
||||||
|
potential_keys.append(b"*")
|
||||||
potential_keys.append((commonname, tuple(sans)))
|
potential_keys.append((commonname, tuple(sans)))
|
||||||
|
|
||||||
name = next(
|
name = next(
|
||||||
|
@ -479,7 +479,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest):
|
|||||||
ssl = True
|
ssl = True
|
||||||
ssloptions = pathod.SSLOptions(
|
ssloptions = pathod.SSLOptions(
|
||||||
certs=[
|
certs=[
|
||||||
(b"*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
|
("*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1142,7 +1142,7 @@ class AddUpstreamCertsToClientChainMixin:
|
|||||||
ssloptions = pathod.SSLOptions(
|
ssloptions = pathod.SSLOptions(
|
||||||
cn=b"example.mitmproxy.org",
|
cn=b"example.mitmproxy.org",
|
||||||
certs=[
|
certs=[
|
||||||
(b"example.mitmproxy.org", servercert)
|
("example.mitmproxy.org", servercert)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -102,7 +102,7 @@ class TestCertStore:
|
|||||||
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
|
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
|
||||||
dcp = tmpdir.join("dc")
|
dcp = tmpdir.join("dc")
|
||||||
dcp.write(dc[0].to_pem())
|
dcp.write(dc[0].to_pem())
|
||||||
ca1.add_cert_file(b"foo.com", str(dcp))
|
ca1.add_cert_file("foo.com", str(dcp))
|
||||||
|
|
||||||
ret = ca1.get_cert(b"foo.com", [])
|
ret = ca1.get_cert(b"foo.com", [])
|
||||||
assert ret[0].serial == dc[0].serial
|
assert ret[0].serial == dc[0].serial
|
||||||
|
@ -57,7 +57,7 @@ class TestNotAfterConnect(tservers.DaemonTests):
|
|||||||
class TestCustomCert(tservers.DaemonTests):
|
class TestCustomCert(tservers.DaemonTests):
|
||||||
ssl = True
|
ssl = True
|
||||||
ssloptions = dict(
|
ssloptions = dict(
|
||||||
certs=[(b"*", tutils.test_data.path("pathod/data/testkey.pem"))],
|
certs=[("*", tutils.test_data.path("pathod/data/testkey.pem"))],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_connect(self):
|
def test_connect(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user