mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-29 19:08:44 +00:00
fix #2563
This commit is contained in:
parent
1f3fec2a3e
commit
4a6d838ecc
@ -4,6 +4,7 @@ import time
|
||||
import datetime
|
||||
import ipaddress
|
||||
import sys
|
||||
import typing
|
||||
|
||||
from pyasn1.type import univ, constraint, char, namedtype, tag
|
||||
from pyasn1.codec.der.decoder import decode
|
||||
@ -122,6 +123,11 @@ class CertStoreEntry:
|
||||
self.chain_file = chain_file
|
||||
|
||||
|
||||
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
|
||||
TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans)
|
||||
TCertId = typing.Union[TCustomCertId, TGeneratedCertId]
|
||||
|
||||
|
||||
class CertStore:
|
||||
|
||||
"""
|
||||
@ -139,7 +145,7 @@ class CertStore:
|
||||
self.default_ca = default_ca
|
||||
self.default_chain_file = default_chain_file
|
||||
self.dhparams = dhparams
|
||||
self.certs = dict()
|
||||
self.certs = {} # type: typing.Dict[TCertId, CertStoreEntry]
|
||||
self.expire_queue = []
|
||||
|
||||
def expire(self, entry):
|
||||
@ -240,7 +246,7 @@ class CertStore:
|
||||
|
||||
return key, ca
|
||||
|
||||
def add_cert_file(self, spec, path):
|
||||
def add_cert_file(self, spec: str, path: str) -> None:
|
||||
with open(path, "rb") as f:
|
||||
raw = f.read()
|
||||
cert = SSLCert(
|
||||
@ -255,10 +261,10 @@ class CertStore:
|
||||
privatekey = self.default_privatekey
|
||||
self.add_cert(
|
||||
CertStoreEntry(cert, privatekey, path),
|
||||
spec
|
||||
spec.encode("idna")
|
||||
)
|
||||
|
||||
def add_cert(self, entry, *names):
|
||||
def add_cert(self, entry: CertStoreEntry, *names: bytes):
|
||||
"""
|
||||
Adds a cert to the certstore. We register the CN in the cert plus
|
||||
any SANs, and also the list of names provided as an argument.
|
||||
@ -271,21 +277,18 @@ class CertStore:
|
||||
self.certs[i] = entry
|
||||
|
||||
@staticmethod
|
||||
def asterisk_forms(dn):
|
||||
if dn is None:
|
||||
return []
|
||||
def asterisk_forms(dn: bytes) -> typing.List[bytes]:
|
||||
"""
|
||||
Return all asterisk forms for a domain. For example, for www.example.com this will return
|
||||
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
|
||||
"""
|
||||
parts = dn.split(b".")
|
||||
parts.reverse()
|
||||
curr_dn = b""
|
||||
dn_forms = [b"*"]
|
||||
for part in parts[:-1]:
|
||||
curr_dn = b"." + part + curr_dn # .example.com
|
||||
dn_forms.append(b"*" + curr_dn) # *.example.com
|
||||
if parts[-1] != b"*":
|
||||
dn_forms.append(parts[-1] + curr_dn)
|
||||
return dn_forms
|
||||
ret = [dn]
|
||||
for i in range(1, len(parts)):
|
||||
ret.append(b"*." + b".".join(parts[i:]))
|
||||
return ret
|
||||
|
||||
def get_cert(self, commonname, sans):
|
||||
def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes]):
|
||||
"""
|
||||
Returns an (cert, privkey, cert_chain) tuple.
|
||||
|
||||
@ -295,9 +298,12 @@ class CertStore:
|
||||
sans: A list of Subject Alternate Names.
|
||||
"""
|
||||
|
||||
potential_keys = self.asterisk_forms(commonname)
|
||||
potential_keys = [] # type: typing.List[TCertId]
|
||||
if commonname:
|
||||
potential_keys.extend(self.asterisk_forms(commonname))
|
||||
for s in sans:
|
||||
potential_keys.extend(self.asterisk_forms(s))
|
||||
potential_keys.append(b"*")
|
||||
potential_keys.append((commonname, tuple(sans)))
|
||||
|
||||
name = next(
|
||||
|
@ -479,7 +479,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest):
|
||||
ssl = True
|
||||
ssloptions = pathod.SSLOptions(
|
||||
certs=[
|
||||
(b"*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
|
||||
("*", tutils.test_data.path("mitmproxy/data/no_common_name.pem"))
|
||||
]
|
||||
)
|
||||
|
||||
@ -1142,7 +1142,7 @@ class AddUpstreamCertsToClientChainMixin:
|
||||
ssloptions = pathod.SSLOptions(
|
||||
cn=b"example.mitmproxy.org",
|
||||
certs=[
|
||||
(b"example.mitmproxy.org", servercert)
|
||||
("example.mitmproxy.org", servercert)
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -102,7 +102,7 @@ class TestCertStore:
|
||||
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
|
||||
dcp = tmpdir.join("dc")
|
||||
dcp.write(dc[0].to_pem())
|
||||
ca1.add_cert_file(b"foo.com", str(dcp))
|
||||
ca1.add_cert_file("foo.com", str(dcp))
|
||||
|
||||
ret = ca1.get_cert(b"foo.com", [])
|
||||
assert ret[0].serial == dc[0].serial
|
||||
|
@ -57,7 +57,7 @@ class TestNotAfterConnect(tservers.DaemonTests):
|
||||
class TestCustomCert(tservers.DaemonTests):
|
||||
ssl = True
|
||||
ssloptions = dict(
|
||||
certs=[(b"*", tutils.test_data.path("pathod/data/testkey.pem"))],
|
||||
certs=[("*", tutils.test_data.path("pathod/data/testkey.pem"))],
|
||||
)
|
||||
|
||||
def test_connect(self):
|
||||
|
Loading…
Reference in New Issue
Block a user