mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 07:08:10 +00:00
Merge pull request #4374 from mhils/cryptography-certs
Use cryptography for certificate generation
This commit is contained in:
commit
7d67eefe29
@ -38,6 +38,7 @@ If you depend on these features, please raise your voice in
|
||||
### Full Changelog
|
||||
|
||||
* New Proxy Core based on sans-io pattern (@mhils)
|
||||
* Use pyca/cryptography to generate certificates, not pyOpenSSL (@mhils)
|
||||
* Remove the legacy protocol stack (@Kriechi)
|
||||
* Remove all deprecated pathod and pathoc tools and modules (@Kriechi)
|
||||
* --- TODO: add new PRs above this line ---
|
||||
|
@ -66,7 +66,7 @@ class NextLayer:
|
||||
pass
|
||||
else:
|
||||
if sni:
|
||||
hostnames.append(sni.decode("idna"))
|
||||
hostnames.append(sni)
|
||||
|
||||
if not hostnames:
|
||||
return False
|
||||
|
@ -1,7 +1,8 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple, TypedDict, Any
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, TypedDict, Any
|
||||
|
||||
from OpenSSL import SSL, crypto
|
||||
from OpenSSL import SSL
|
||||
from mitmproxy import certs, ctx, exceptions
|
||||
from mitmproxy.net import tls as net_tls
|
||||
from mitmproxy.options import CONF_BASENAME
|
||||
@ -115,7 +116,7 @@ class TlsConfig:
|
||||
client: context.Client = tls_start.context.client
|
||||
server: context.Server = tls_start.context.server
|
||||
|
||||
cert, key, chain_file = self.get_cert(tls_start.context)
|
||||
entry = self.get_cert(tls_start.context)
|
||||
|
||||
if not client.cipher_list and ctx.options.ciphers_client:
|
||||
client.cipher_list = ctx.options.ciphers_client.split(":")
|
||||
@ -126,9 +127,9 @@ class TlsConfig:
|
||||
min_version=net_tls.Version[ctx.options.tls_version_client_min],
|
||||
max_version=net_tls.Version[ctx.options.tls_version_client_max],
|
||||
cipher_list=cipher_list,
|
||||
cert=cert,
|
||||
key=key,
|
||||
chain_file=chain_file,
|
||||
cert=entry.cert,
|
||||
key=entry.privatekey,
|
||||
chain_file=entry.chain_file,
|
||||
request_client_cert=False,
|
||||
alpn_select_callback=alpn_select_callback,
|
||||
extra_chain_certs=server.certificate_list,
|
||||
@ -152,8 +153,7 @@ class TlsConfig:
|
||||
verify = net_tls.Verify.VERIFY_PEER
|
||||
|
||||
if server.sni is True:
|
||||
server.sni = client.sni or server.address[0].encode()
|
||||
sni = server.sni or None # make sure that false-y values are None
|
||||
server.sni = client.sni or server.address[0]
|
||||
|
||||
if not server.alpn_offers:
|
||||
if client.alpn_offers:
|
||||
@ -182,7 +182,7 @@ class TlsConfig:
|
||||
if os.path.isfile(client_certs):
|
||||
client_cert = client_certs
|
||||
else:
|
||||
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
|
||||
server_name: str = server.sni or server.address[0]
|
||||
p = os.path.join(client_certs, f"{server_name}.pem")
|
||||
if os.path.isfile(p):
|
||||
client_cert = p
|
||||
@ -192,7 +192,7 @@ class TlsConfig:
|
||||
max_version=net_tls.Version[ctx.options.tls_version_client_max],
|
||||
cipher_list=cipher_list,
|
||||
verify=verify,
|
||||
sni=sni,
|
||||
sni=server.sni,
|
||||
ca_path=ctx.options.ssl_verify_upstream_trusted_confdir,
|
||||
ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca,
|
||||
client_cert=client_cert,
|
||||
@ -200,7 +200,8 @@ class TlsConfig:
|
||||
)
|
||||
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||
tls_start.ssl_conn.set_tlsext_host_name(server.sni)
|
||||
if server.sni:
|
||||
tls_start.ssl_conn.set_tlsext_host_name(server.sni.encode())
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
|
||||
def running(self):
|
||||
@ -232,8 +233,8 @@ class TlsConfig:
|
||||
if len(parts) == 1:
|
||||
parts = ["*", parts[0]]
|
||||
|
||||
cert = os.path.expanduser(parts[1])
|
||||
if not os.path.exists(cert):
|
||||
cert = Path(parts[1]).expanduser()
|
||||
if not cert.exists():
|
||||
raise exceptions.OptionsError(f"Certificate file does not exist: {cert}")
|
||||
try:
|
||||
self.certstore.add_cert_file(
|
||||
@ -241,16 +242,16 @@ class TlsConfig:
|
||||
cert,
|
||||
passphrase=ctx.options.cert_passphrase.encode("utf8") if ctx.options.cert_passphrase else None,
|
||||
)
|
||||
except crypto.Error as e:
|
||||
raise exceptions.OptionsError(f"Invalid certificate format: {cert}") from e
|
||||
except ValueError as e:
|
||||
raise exceptions.OptionsError(f"Invalid certificate format for {cert}: {e}") from e
|
||||
|
||||
def get_cert(self, conn_context: context.Context) -> Tuple[certs.Cert, SSL.PKey, str]:
|
||||
def get_cert(self, conn_context: context.Context) -> certs.CertStoreEntry:
|
||||
"""
|
||||
This function determines the Common Name (CN), Subject Alternative Names (SANs) and Organization Name
|
||||
our certificate should have and then fetches a matching cert from the certstore.
|
||||
"""
|
||||
altnames: List[bytes] = []
|
||||
organization: Optional[bytes] = None
|
||||
altnames: List[str] = []
|
||||
organization: Optional[str] = None
|
||||
|
||||
# Use upstream certificate if available.
|
||||
if conn_context.server.certificate_list:
|
||||
@ -265,11 +266,11 @@ class TlsConfig:
|
||||
if conn_context.client.sni:
|
||||
altnames.append(conn_context.client.sni)
|
||||
elif conn_context.server.address:
|
||||
altnames.append(conn_context.server.address[0].encode("idna"))
|
||||
altnames.append(conn_context.server.address[0])
|
||||
|
||||
# As a last resort, add *something* so that we have a certificate to serve.
|
||||
if not altnames:
|
||||
altnames.append(b"mitmproxy")
|
||||
altnames.append("mitmproxy")
|
||||
|
||||
# only keep first occurrence of each hostname
|
||||
altnames = list(dict.fromkeys(altnames))
|
||||
|
@ -1,23 +1,24 @@
|
||||
import os
|
||||
import ssl
|
||||
import time
|
||||
import contextlib
|
||||
import datetime
|
||||
import ipaddress
|
||||
import os
|
||||
import sys
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional, Union, Dict, List
|
||||
from typing import Tuple, Optional, Union, Dict, List, NewType
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec
|
||||
from cryptography.hazmat.primitives.serialization import pkcs12
|
||||
from cryptography.x509 import NameOID, ExtendedKeyUsageOID
|
||||
|
||||
from pyasn1.type import univ, constraint, char, namedtype, tag
|
||||
from pyasn1.codec.der.decoder import decode
|
||||
from pyasn1.error import PyAsn1Error
|
||||
import OpenSSL
|
||||
|
||||
from mitmproxy.coretypes import serializable
|
||||
|
||||
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
|
||||
DEFAULT_EXP = 94608000 # = 60 * 60 * 24 * 365 * 3 = 3 years
|
||||
DEFAULT_EXP_DUMMY_CERT = 31536000 # = 60 * 60 * 24 * 365 = 1 year
|
||||
CA_EXPIRY = datetime.timedelta(days=3 * 365)
|
||||
CERT_EXPIRY = datetime.timedelta(days=365)
|
||||
|
||||
# Generated with "openssl dhparam". It's too slow to generate this on startup.
|
||||
DEFAULT_DHPARAM = b"""
|
||||
@ -37,51 +38,166 @@ rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
|
||||
"""
|
||||
|
||||
|
||||
def create_ca(organization: str, cn: str, exp: int, key_size: int) -> Tuple[OpenSSL.crypto.PKey, OpenSSL.crypto.X509]:
|
||||
key = OpenSSL.crypto.PKey()
|
||||
key.generate_key(OpenSSL.crypto.TYPE_RSA, key_size)
|
||||
cert = OpenSSL.crypto.X509()
|
||||
cert.set_serial_number(int(time.time() * 10000))
|
||||
cert.set_version(2)
|
||||
cert.get_subject().CN = cn
|
||||
cert.get_subject().O = organization
|
||||
cert.gmtime_adj_notBefore(-3600 * 48)
|
||||
cert.gmtime_adj_notAfter(exp)
|
||||
cert.set_issuer(cert.get_subject())
|
||||
cert.set_pubkey(key)
|
||||
cert.add_extensions([
|
||||
OpenSSL.crypto.X509Extension(
|
||||
b"basicConstraints",
|
||||
True,
|
||||
b"CA:TRUE"
|
||||
),
|
||||
OpenSSL.crypto.X509Extension(
|
||||
b"nsCertType",
|
||||
False,
|
||||
b"sslCA"
|
||||
),
|
||||
OpenSSL.crypto.X509Extension(
|
||||
b"extendedKeyUsage",
|
||||
False,
|
||||
b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC"
|
||||
),
|
||||
OpenSSL.crypto.X509Extension(
|
||||
b"keyUsage",
|
||||
True,
|
||||
b"keyCertSign, cRLSign"
|
||||
),
|
||||
OpenSSL.crypto.X509Extension(
|
||||
b"subjectKeyIdentifier",
|
||||
False,
|
||||
b"hash",
|
||||
subject=cert
|
||||
),
|
||||
class Cert(serializable.Serializable):
|
||||
_cert: x509.Certificate
|
||||
|
||||
def __init__(self, cert: x509.Certificate):
|
||||
assert isinstance(cert, x509.Certificate)
|
||||
self._cert = cert
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.fingerprint() == other.fingerprint()
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
return cls.from_pem(state)
|
||||
|
||||
def get_state(self):
|
||||
return self.to_pem()
|
||||
|
||||
def set_state(self, state):
|
||||
self._cert = x509.load_pem_x509_certificate(state)
|
||||
|
||||
@classmethod
|
||||
def from_pem(cls, data: bytes) -> "Cert":
|
||||
cert = x509.load_pem_x509_certificate(data) # type: ignore
|
||||
return cls(cert)
|
||||
|
||||
def to_pem(self) -> bytes:
|
||||
return self._cert.public_bytes(serialization.Encoding.PEM)
|
||||
|
||||
@classmethod
|
||||
def from_pyopenssl(self, x509: OpenSSL.crypto.X509) -> "Cert":
|
||||
return Cert(x509.to_cryptography())
|
||||
|
||||
def to_pyopenssl(self) -> OpenSSL.crypto.X509:
|
||||
return OpenSSL.crypto.X509.from_cryptography(self._cert)
|
||||
|
||||
def fingerprint(self) -> bytes:
|
||||
return self._cert.fingerprint(hashes.SHA256())
|
||||
|
||||
@property
|
||||
def issuer(self) -> List[Tuple[str, str]]:
|
||||
return _name_to_keyval(self._cert.issuer)
|
||||
|
||||
@property
|
||||
def notbefore(self) -> datetime.datetime:
|
||||
return self._cert.not_valid_before
|
||||
|
||||
@property
|
||||
def notafter(self) -> datetime.datetime:
|
||||
return self._cert.not_valid_after
|
||||
|
||||
def has_expired(self) -> bool:
|
||||
return datetime.datetime.utcnow() > self._cert.not_valid_after
|
||||
|
||||
@property
|
||||
def subject(self) -> List[Tuple[str, str]]:
|
||||
return _name_to_keyval(self._cert.subject)
|
||||
|
||||
@property
|
||||
def serial(self) -> int:
|
||||
return self._cert.serial_number
|
||||
|
||||
@property
|
||||
def keyinfo(self):
|
||||
public_key = self._cert.public_key()
|
||||
if isinstance(public_key, rsa.RSAPublicKey):
|
||||
return "RSA", public_key.key_size
|
||||
if isinstance(public_key, dsa.DSAPublicKey):
|
||||
return "DSA", public_key.key_size
|
||||
if isinstance(public_key, ec.EllipticCurvePublicKey):
|
||||
return f"EC ({public_key.curve.name})", public_key.key_size
|
||||
return (public_key.__class__.__name__.replace("PublicKey", "").replace("_", ""),
|
||||
getattr(public_key, "key_size", -1)) # pragma: no cover
|
||||
|
||||
@property
|
||||
def cn(self) -> Optional[str]:
|
||||
attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
|
||||
if attrs:
|
||||
return attrs[0].value
|
||||
return None
|
||||
|
||||
@property
|
||||
def organization(self) -> Optional[str]:
|
||||
attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.ORGANIZATION_NAME)
|
||||
if attrs:
|
||||
return attrs[0].value
|
||||
return None
|
||||
|
||||
@property
|
||||
def altnames(self) -> List[str]:
|
||||
"""
|
||||
Get all SubjectAlternativeName DNS altnames.
|
||||
"""
|
||||
try:
|
||||
ext = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
|
||||
except x509.ExtensionNotFound:
|
||||
return []
|
||||
else:
|
||||
return (
|
||||
ext.get_values_for_type(x509.DNSName)
|
||||
+
|
||||
[str(x) for x in ext.get_values_for_type(x509.IPAddress)]
|
||||
)
|
||||
|
||||
|
||||
def _name_to_keyval(name: x509.Name) -> List[Tuple[str, str]]:
|
||||
parts = []
|
||||
for rdn in name.rdns:
|
||||
k, v = rdn.rfc4514_string().split("=", maxsplit=1)
|
||||
parts.append((k, v))
|
||||
return parts
|
||||
|
||||
|
||||
def create_ca(
|
||||
organization: str,
|
||||
cn: str,
|
||||
key_size: int,
|
||||
) -> Tuple[rsa.RSAPrivateKeyWithSerialization, x509.Certificate]:
|
||||
now = datetime.datetime.now()
|
||||
|
||||
private_key = rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=key_size,
|
||||
) # type: ignore
|
||||
name = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, cn),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization)
|
||||
])
|
||||
cert.sign(key, "sha256")
|
||||
return key, cert
|
||||
builder = x509.CertificateBuilder()
|
||||
builder = builder.serial_number(x509.random_serial_number())
|
||||
builder = builder.subject_name(name)
|
||||
builder = builder.not_valid_before(now - datetime.timedelta(days=2))
|
||||
builder = builder.not_valid_after(now + CA_EXPIRY)
|
||||
builder = builder.issuer_name(name)
|
||||
builder = builder.public_key(private_key.public_key())
|
||||
builder = builder.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
|
||||
builder = builder.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False)
|
||||
builder = builder.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=False,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
), critical=True)
|
||||
builder = builder.add_extension(x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()), critical=False)
|
||||
cert = builder.sign(private_key=private_key, algorithm=hashes.SHA256()) # type: ignore
|
||||
return private_key, cert
|
||||
|
||||
|
||||
def dummy_cert(privkey, cacert, commonname, sans, organization):
|
||||
def dummy_cert(
|
||||
privkey: rsa.RSAPrivateKey,
|
||||
cacert: x509.Certificate,
|
||||
commonname: Optional[str],
|
||||
sans: List[str],
|
||||
organization: Optional[str] = None,
|
||||
) -> Cert:
|
||||
"""
|
||||
Generates a dummy certificate.
|
||||
|
||||
@ -93,111 +209,114 @@ def dummy_cert(privkey, cacert, commonname, sans, organization):
|
||||
|
||||
Returns cert if operation succeeded, None if not.
|
||||
"""
|
||||
ss = []
|
||||
for i in sans:
|
||||
try:
|
||||
ipaddress.ip_address(i.decode("ascii"))
|
||||
except ValueError:
|
||||
ss.append(b"DNS:%s" % i)
|
||||
else:
|
||||
ss.append(b"IP:%s" % i)
|
||||
ss = b", ".join(ss)
|
||||
builder = x509.CertificateBuilder()
|
||||
builder = builder.issuer_name(cacert.subject)
|
||||
builder = builder.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False)
|
||||
builder = builder.public_key(cacert.public_key())
|
||||
|
||||
cert = OpenSSL.crypto.X509()
|
||||
cert.gmtime_adj_notBefore(-3600 * 48)
|
||||
cert.gmtime_adj_notAfter(DEFAULT_EXP_DUMMY_CERT)
|
||||
cert.set_issuer(cacert.get_subject())
|
||||
now = datetime.datetime.now()
|
||||
builder = builder.not_valid_before(now - datetime.timedelta(days=2))
|
||||
builder = builder.not_valid_after(now + CERT_EXPIRY)
|
||||
|
||||
subject = []
|
||||
is_valid_commonname = (
|
||||
commonname is not None and len(commonname) < 64
|
||||
commonname is not None and len(commonname) < 64
|
||||
)
|
||||
if is_valid_commonname:
|
||||
cert.get_subject().CN = commonname
|
||||
assert commonname is not None
|
||||
subject.append(x509.NameAttribute(NameOID.COMMON_NAME, commonname))
|
||||
if organization is not None:
|
||||
cert.get_subject().O = organization
|
||||
cert.set_serial_number(int(time.time() * 10000))
|
||||
if ss:
|
||||
cert.set_version(2)
|
||||
cert.add_extensions(
|
||||
[OpenSSL.crypto.X509Extension(
|
||||
b"subjectAltName",
|
||||
# RFC 5280 §4.2.1.6: subjectAltName is critical if subject is empty.
|
||||
not is_valid_commonname,
|
||||
ss
|
||||
)]
|
||||
)
|
||||
cert.add_extensions([
|
||||
OpenSSL.crypto.X509Extension(
|
||||
b"extendedKeyUsage",
|
||||
False,
|
||||
b"serverAuth,clientAuth"
|
||||
)
|
||||
])
|
||||
cert.set_pubkey(cacert.get_pubkey())
|
||||
cert.sign(privkey, "sha256")
|
||||
assert organization is not None
|
||||
subject.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization))
|
||||
builder = builder.subject_name(x509.Name(subject))
|
||||
builder = builder.serial_number(x509.random_serial_number())
|
||||
|
||||
ss: List[x509.GeneralName] = []
|
||||
for x in sans:
|
||||
try:
|
||||
ip = ipaddress.ip_address(x)
|
||||
except ValueError:
|
||||
ss.append(x509.DNSName(x))
|
||||
else:
|
||||
ss.append(x509.IPAddress(ip))
|
||||
# RFC 5280 §4.2.1.6: subjectAltName is critical if subject is empty.
|
||||
builder = builder.add_extension(x509.SubjectAlternativeName(ss), critical=not is_valid_commonname)
|
||||
cert = builder.sign(private_key=privkey, algorithm=hashes.SHA256()) # type: ignore
|
||||
return Cert(cert)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CertStoreEntry:
|
||||
|
||||
def __init__(self, cert, privatekey, chain_file):
|
||||
self.cert = cert
|
||||
self.privatekey = privatekey
|
||||
self.chain_file = chain_file
|
||||
cert: Cert
|
||||
privatekey: rsa.RSAPrivateKey
|
||||
chain_file: Optional[Path]
|
||||
|
||||
|
||||
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
|
||||
TGeneratedCertId = Tuple[Optional[bytes], Tuple[bytes, ...]] # (common_name, sans)
|
||||
TCustomCertId = str # manually provided certs (e.g. mitmproxy's --certs)
|
||||
TGeneratedCertId = Tuple[Optional[str], Tuple[str, ...]] # (common_name, sans)
|
||||
TCertId = Union[TCustomCertId, TGeneratedCertId]
|
||||
|
||||
DHParams = NewType("DHParams", bytes)
|
||||
|
||||
|
||||
class CertStore:
|
||||
|
||||
"""
|
||||
Implements an in-memory certificate store.
|
||||
"""
|
||||
STORE_CAP = 100
|
||||
certs: Dict[TCertId, CertStoreEntry]
|
||||
expire_queue: List[CertStoreEntry]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_privatekey,
|
||||
default_ca,
|
||||
default_chain_file,
|
||||
dhparams):
|
||||
default_privatekey: rsa.RSAPrivateKey,
|
||||
default_ca: Cert,
|
||||
default_chain_file: Optional[Path],
|
||||
dhparams: DHParams
|
||||
):
|
||||
self.default_privatekey = default_privatekey
|
||||
self.default_ca = default_ca
|
||||
self.default_chain_file = default_chain_file
|
||||
self.dhparams = dhparams
|
||||
self.certs: Dict[TCertId, CertStoreEntry] = {}
|
||||
self.certs = {}
|
||||
self.expire_queue = []
|
||||
|
||||
def expire(self, entry):
|
||||
def expire(self, entry: CertStoreEntry) -> None:
|
||||
self.expire_queue.append(entry)
|
||||
if len(self.expire_queue) > self.STORE_CAP:
|
||||
d = self.expire_queue.pop(0)
|
||||
self.certs = {k: v for k, v in self.certs.items() if v != d}
|
||||
|
||||
@staticmethod
|
||||
def load_dhparam(path: str):
|
||||
|
||||
def load_dhparam(path: Path) -> DHParams:
|
||||
# mitmproxy<=0.10 doesn't generate a dhparam file.
|
||||
# Create it now if necessary.
|
||||
if not os.path.exists(path):
|
||||
with open(path, "wb") as f:
|
||||
f.write(DEFAULT_DHPARAM)
|
||||
if not path.exists():
|
||||
path.write_bytes(DEFAULT_DHPARAM)
|
||||
|
||||
bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r")
|
||||
# we could use cryptography for this, but it's unclear how to convert cryptography's object to pyOpenSSL's
|
||||
# expected format.
|
||||
bio = OpenSSL.SSL._lib.BIO_new_file(str(path).encode(sys.getfilesystemencoding()), b"r")
|
||||
if bio != OpenSSL.SSL._ffi.NULL:
|
||||
bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)
|
||||
dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams(
|
||||
bio,
|
||||
OpenSSL.SSL._ffi.NULL,
|
||||
OpenSSL.SSL._ffi.NULL,
|
||||
OpenSSL.SSL._ffi.NULL)
|
||||
OpenSSL.SSL._ffi.NULL
|
||||
)
|
||||
dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)
|
||||
return dh
|
||||
raise RuntimeError("Error loading DH Params.") # pragma: no cover
|
||||
|
||||
@classmethod
|
||||
def from_store(cls, path: Union[Path, str], basename: str, key_size, passphrase: Optional[bytes] = None) -> "CertStore":
|
||||
def from_store(
|
||||
cls,
|
||||
path: Union[Path, str],
|
||||
basename: str,
|
||||
key_size: int,
|
||||
passphrase: Optional[bytes] = None
|
||||
) -> "CertStore":
|
||||
path = Path(path)
|
||||
ca_file = path / f"{basename}-ca.pem"
|
||||
dhparam_file = path / f"{basename}-dhparam.pem"
|
||||
@ -208,17 +327,14 @@ class CertStore:
|
||||
@classmethod
|
||||
def from_files(cls, ca_file: Path, dhparam_file: Path, passphrase: Optional[bytes] = None) -> "CertStore":
|
||||
raw = ca_file.read_bytes()
|
||||
ca = OpenSSL.crypto.load_certificate(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
raw
|
||||
)
|
||||
key = OpenSSL.crypto.load_privatekey(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
raw,
|
||||
passphrase
|
||||
)
|
||||
dh = cls.load_dhparam(str(dhparam_file))
|
||||
return cls(key, ca, str(ca_file), dh)
|
||||
key = load_pem_private_key(raw, passphrase)
|
||||
ca = Cert.from_pem(raw)
|
||||
dh = cls.load_dhparam(dhparam_file)
|
||||
if raw.count(b"BEGIN CERTIFICATE") != 1:
|
||||
chain_file: Optional[Path] = ca_file
|
||||
else:
|
||||
chain_file = None
|
||||
return cls(key, ca, chain_file, dh)
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
@ -236,74 +352,72 @@ class CertStore:
|
||||
os.umask(original_umask)
|
||||
|
||||
@staticmethod
|
||||
def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None, expiry=DEFAULT_EXP) -> None:
|
||||
def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None) -> None:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
organization = organization or basename
|
||||
cn = cn or basename
|
||||
|
||||
key, ca = create_ca(organization=organization, cn=cn, exp=expiry, key_size=key_size)
|
||||
# Dump the CA plus private key
|
||||
with CertStore.umask_secret(), (path / f"{basename}-ca.pem").open("wb") as f:
|
||||
f.write(
|
||||
OpenSSL.crypto.dump_privatekey(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
key))
|
||||
f.write(
|
||||
OpenSSL.crypto.dump_certificate(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
ca))
|
||||
key: rsa.RSAPrivateKeyWithSerialization
|
||||
ca: x509.Certificate
|
||||
key, ca = create_ca(organization=organization, cn=cn, key_size=key_size)
|
||||
|
||||
# Dump the CA plus private key.
|
||||
with CertStore.umask_secret():
|
||||
# PEM format
|
||||
(path / f"{basename}-ca.pem").write_bytes(
|
||||
key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
) +
|
||||
ca.public_bytes(serialization.Encoding.PEM)
|
||||
)
|
||||
|
||||
# PKCS12 format for Windows devices
|
||||
(path / f"{basename}-ca.p12").write_bytes(
|
||||
pkcs12.serialize_key_and_certificates( # type: ignore
|
||||
name=basename.encode(),
|
||||
key=key,
|
||||
cert=ca,
|
||||
cas=None,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
|
||||
# Dump the certificate in PEM format
|
||||
with (path / f"{basename}-ca-cert.pem").open("wb") as f:
|
||||
f.write(
|
||||
OpenSSL.crypto.dump_certificate(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
ca))
|
||||
|
||||
pem_cert = ca.public_bytes(serialization.Encoding.PEM)
|
||||
(path / f"{basename}-ca-cert.pem").write_bytes(pem_cert)
|
||||
# Create a .cer file with the same contents for Android
|
||||
with (path / f"{basename}-ca-cert.cer").open("wb") as f:
|
||||
f.write(
|
||||
OpenSSL.crypto.dump_certificate(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
ca))
|
||||
(path / f"{basename}-ca-cert.cer").write_bytes(pem_cert)
|
||||
|
||||
# Dump the certificate in PKCS12 format for Windows devices
|
||||
with (path / f"{basename}-ca-cert.p12").open("wb") as f:
|
||||
p12 = OpenSSL.crypto.PKCS12()
|
||||
p12.set_certificate(ca)
|
||||
f.write(p12.export())
|
||||
|
||||
# Dump the certificate and key in a PKCS12 format for Windows devices
|
||||
with CertStore.umask_secret(), (path / f"{basename}-ca.p12").open("wb") as f:
|
||||
p12 = OpenSSL.crypto.PKCS12()
|
||||
p12.set_certificate(ca)
|
||||
p12.set_privatekey(key)
|
||||
f.write(p12.export())
|
||||
|
||||
with (path / f"{basename}-dhparam.pem").open("wb") as f:
|
||||
f.write(DEFAULT_DHPARAM)
|
||||
|
||||
def add_cert_file(self, spec: str, path: str, passphrase: Optional[bytes] = None) -> None:
|
||||
with open(path, "rb") as f:
|
||||
raw = f.read()
|
||||
cert = Cert(
|
||||
OpenSSL.crypto.load_certificate(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
raw))
|
||||
try:
|
||||
privatekey = OpenSSL.crypto.load_privatekey(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
raw,
|
||||
passphrase)
|
||||
except Exception:
|
||||
privatekey = self.default_privatekey
|
||||
self.add_cert(
|
||||
CertStoreEntry(cert, privatekey, path),
|
||||
spec.encode("idna")
|
||||
(path / f"{basename}-ca-cert.p12").write_bytes(
|
||||
pkcs12.serialize_key_and_certificates( # type: ignore
|
||||
name=basename.encode(),
|
||||
key=None,
|
||||
cert=ca,
|
||||
cas=None,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
|
||||
def add_cert(self, entry: CertStoreEntry, *names: bytes):
|
||||
(path / f"{basename}-dhparam.pem").write_bytes(DEFAULT_DHPARAM)
|
||||
|
||||
def add_cert_file(self, spec: str, path: Path, passphrase: Optional[bytes] = None) -> None:
|
||||
raw = path.read_bytes()
|
||||
cert = Cert.from_pem(raw)
|
||||
try:
|
||||
key = load_pem_private_key(raw, password=passphrase)
|
||||
except ValueError:
|
||||
key = self.default_privatekey
|
||||
|
||||
self.add_cert(
|
||||
CertStoreEntry(cert, key, path),
|
||||
spec
|
||||
)
|
||||
|
||||
def add_cert(self, entry: CertStoreEntry, *names: str) -> None:
|
||||
"""
|
||||
Adds a cert to the certstore. We register the CN in the cert plus
|
||||
any SANs, and also the list of names provided as an argument.
|
||||
@ -316,26 +430,24 @@ class CertStore:
|
||||
self.certs[i] = entry
|
||||
|
||||
@staticmethod
|
||||
def asterisk_forms(dn: bytes) -> List[bytes]:
|
||||
def asterisk_forms(dn: str) -> List[str]:
|
||||
"""
|
||||
Return all asterisk forms for a domain. For example, for www.example.com this will return
|
||||
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
|
||||
"""
|
||||
parts = dn.split(b".")
|
||||
parts = dn.split(".")
|
||||
ret = [dn]
|
||||
for i in range(1, len(parts)):
|
||||
ret.append(b"*." + b".".join(parts[i:]))
|
||||
ret.append("*." + ".".join(parts[i:]))
|
||||
return ret
|
||||
|
||||
def get_cert(
|
||||
self,
|
||||
commonname: Optional[bytes],
|
||||
sans: List[bytes],
|
||||
organization: Optional[bytes] = None
|
||||
) -> Tuple["Cert", OpenSSL.SSL.PKey, str]:
|
||||
commonname: Optional[str],
|
||||
sans: List[str],
|
||||
organization: Optional[str] = None
|
||||
) -> CertStoreEntry:
|
||||
"""
|
||||
Returns an (cert, privkey, cert_chain) tuple.
|
||||
|
||||
commonname: Common name for the generated certificate. Must be a
|
||||
valid, plain-ASCII, IDNA-encoded domain name.
|
||||
|
||||
@ -349,7 +461,7 @@ class CertStore:
|
||||
potential_keys.extend(self.asterisk_forms(commonname))
|
||||
for s in sans:
|
||||
potential_keys.extend(self.asterisk_forms(s))
|
||||
potential_keys.append(b"*")
|
||||
potential_keys.append("*")
|
||||
potential_keys.append((commonname, tuple(sans)))
|
||||
|
||||
name = next(
|
||||
@ -362,147 +474,27 @@ class CertStore:
|
||||
entry = CertStoreEntry(
|
||||
cert=dummy_cert(
|
||||
self.default_privatekey,
|
||||
self.default_ca,
|
||||
self.default_ca._cert,
|
||||
commonname,
|
||||
sans,
|
||||
organization),
|
||||
privatekey=self.default_privatekey,
|
||||
chain_file=self.default_chain_file)
|
||||
chain_file=self.default_chain_file
|
||||
)
|
||||
self.certs[(commonname, tuple(sans))] = entry
|
||||
self.expire(entry)
|
||||
|
||||
return entry.cert, entry.privatekey, entry.chain_file
|
||||
return entry
|
||||
|
||||
|
||||
class _GeneralName(univ.Choice):
|
||||
# We only care about dNSName and iPAddress
|
||||
componentType = namedtype.NamedTypes(
|
||||
namedtype.NamedType('dNSName', char.IA5String().subtype(
|
||||
implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2)
|
||||
)),
|
||||
namedtype.NamedType('iPAddress', univ.OctetString().subtype(
|
||||
implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 7)
|
||||
)),
|
||||
)
|
||||
|
||||
|
||||
class _GeneralNames(univ.SequenceOf):
|
||||
componentType = _GeneralName()
|
||||
sizeSpec = univ.SequenceOf.sizeSpec + \
|
||||
constraint.ValueSizeConstraint(1, 1024)
|
||||
|
||||
|
||||
class Cert(serializable.Serializable):
|
||||
|
||||
def __init__(self, cert):
|
||||
"""
|
||||
Returns a (common name, [subject alternative names]) tuple.
|
||||
"""
|
||||
self.x509 = cert
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.digest("sha256") == other.digest("sha256")
|
||||
|
||||
def get_state(self):
|
||||
return self.to_pem()
|
||||
|
||||
def set_state(self, state):
|
||||
self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
return cls.from_pem(state)
|
||||
|
||||
@classmethod
|
||||
def from_pem(cls, txt):
|
||||
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
|
||||
return cls(x509)
|
||||
|
||||
@classmethod
|
||||
def from_der(cls, der):
|
||||
pem = ssl.DER_cert_to_PEM_cert(der)
|
||||
return cls.from_pem(pem)
|
||||
|
||||
def to_pem(self):
|
||||
return OpenSSL.crypto.dump_certificate(
|
||||
OpenSSL.crypto.FILETYPE_PEM,
|
||||
self.x509)
|
||||
|
||||
def digest(self, name):
|
||||
return self.x509.digest(name)
|
||||
|
||||
@property
|
||||
def issuer(self):
|
||||
return self.x509.get_issuer().get_components()
|
||||
|
||||
@property
|
||||
def notbefore(self):
|
||||
t = self.x509.get_notBefore()
|
||||
return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ")
|
||||
|
||||
@property
|
||||
def notafter(self):
|
||||
t = self.x509.get_notAfter()
|
||||
return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ")
|
||||
|
||||
@property
|
||||
def has_expired(self):
|
||||
return self.x509.has_expired()
|
||||
|
||||
@property
|
||||
def subject(self):
|
||||
return self.x509.get_subject().get_components()
|
||||
|
||||
@property
|
||||
def serial(self):
|
||||
return self.x509.get_serial_number()
|
||||
|
||||
@property
|
||||
def keyinfo(self):
|
||||
pk = self.x509.get_pubkey()
|
||||
types = {
|
||||
OpenSSL.crypto.TYPE_RSA: "RSA",
|
||||
OpenSSL.crypto.TYPE_DSA: "DSA",
|
||||
}
|
||||
return (
|
||||
types.get(pk.type(), "UNKNOWN"),
|
||||
pk.bits()
|
||||
)
|
||||
|
||||
@property
|
||||
def cn(self) -> Optional[bytes]:
|
||||
c = None
|
||||
for i in self.subject:
|
||||
if i[0] == b"CN":
|
||||
c = i[1]
|
||||
return c
|
||||
|
||||
@property
|
||||
def organization(self) -> Optional[bytes]:
|
||||
c = None
|
||||
for i in self.subject:
|
||||
if i[0] == b"O":
|
||||
c = i[1]
|
||||
return c
|
||||
|
||||
@property
|
||||
def altnames(self) -> List[bytes]:
|
||||
"""
|
||||
Returns:
|
||||
All DNS altnames.
|
||||
"""
|
||||
# tcp.TCPClient.convert_to_tls assumes that this property only contains DNS altnames for hostname verification.
|
||||
altnames = []
|
||||
for i in range(self.x509.get_extension_count()):
|
||||
ext = self.x509.get_extension(i)
|
||||
if ext.get_short_name() == b"subjectAltName":
|
||||
try:
|
||||
dec = decode(ext.get_data(), asn1Spec=_GeneralNames())
|
||||
except PyAsn1Error:
|
||||
continue
|
||||
for x in dec[0]:
|
||||
if x[0].hasValue():
|
||||
e = x[0].asOctets()
|
||||
altnames.append(e)
|
||||
|
||||
return altnames
|
||||
def load_pem_private_key(data: bytes, password: Optional[bytes]) -> rsa.RSAPrivateKey:
|
||||
"""
|
||||
like cryptography's load_pem_private_key, but silently falls back to not using a password
|
||||
if the private key is unencrypted.
|
||||
"""
|
||||
try:
|
||||
return serialization.load_pem_private_key(data, password) # type: ignore
|
||||
except TypeError:
|
||||
if password is not None:
|
||||
return load_pem_private_key(data, None)
|
||||
raise
|
||||
|
@ -230,6 +230,23 @@ def convert_9_10(data):
|
||||
return data
|
||||
|
||||
|
||||
def convert_10_11(data):
|
||||
data["version"] = 11
|
||||
|
||||
def conv_conn(conn):
|
||||
conn["sni"] = strutils.always_str(conn["sni"], "ascii", "backslashreplace")
|
||||
conn["alpn"] = conn.pop("alpn_proto_negotiated")
|
||||
conn["alpn_offers"] = conn["alpn_offers"] or []
|
||||
conn["cipher_list"] = conn["cipher_list"] or []
|
||||
|
||||
conv_conn(data["client_conn"])
|
||||
conv_conn(data["server_conn"])
|
||||
if data["server_conn"]["via"]:
|
||||
conv_conn(data["server_conn"]["via"])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _convert_dict_keys(o: Any) -> Any:
|
||||
if isinstance(o, dict):
|
||||
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
|
||||
@ -287,6 +304,7 @@ converters = {
|
||||
7: convert_7_8,
|
||||
8: convert_8_9,
|
||||
9: convert_9_10,
|
||||
10: convert_10_11,
|
||||
}
|
||||
|
||||
|
||||
|
@ -6,9 +6,10 @@ from pathlib import Path
|
||||
from typing import Iterable, Callable, Optional, Tuple, List, Any, BinaryIO
|
||||
|
||||
import certifi
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from kaitaistruct import KaitaiStream
|
||||
|
||||
from OpenSSL import SSL
|
||||
from OpenSSL import SSL, crypto
|
||||
from mitmproxy import certs
|
||||
from mitmproxy.contrib.kaitaistruct import tls_client_hello
|
||||
from mitmproxy.net import check
|
||||
@ -129,7 +130,7 @@ def create_proxy_server_context(
|
||||
max_version: Version,
|
||||
cipher_list: Optional[Iterable[str]],
|
||||
verify: Verify,
|
||||
sni: Optional[bytes],
|
||||
sni: Optional[str],
|
||||
ca_path: Optional[str],
|
||||
ca_pemfile: Optional[str],
|
||||
client_cert: Optional[str],
|
||||
@ -147,6 +148,7 @@ def create_proxy_server_context(
|
||||
|
||||
context.set_verify(verify.value, None)
|
||||
if sni is not None:
|
||||
assert isinstance(sni, str)
|
||||
# Manually enable hostname verification on the context object.
|
||||
# https://wiki.openssl.org/index.php/Hostname_validation
|
||||
param = SSL._lib.SSL_CTX_get0_param(context._context)
|
||||
@ -157,7 +159,7 @@ def create_proxy_server_context(
|
||||
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
|
||||
)
|
||||
SSL._openssl_assert(
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni, 0) == 1
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni.encode(), 0) == 1
|
||||
)
|
||||
|
||||
if ca_path is None and ca_pemfile is None:
|
||||
@ -188,12 +190,12 @@ def create_client_proxy_context(
|
||||
max_version: Version,
|
||||
cipher_list: Optional[Iterable[str]],
|
||||
cert: certs.Cert,
|
||||
key: SSL.PKey,
|
||||
chain_file: str,
|
||||
key: rsa.RSAPrivateKey,
|
||||
chain_file: Optional[Path],
|
||||
alpn_select_callback: Optional[Callable[[SSL.Connection, List[bytes]], Any]],
|
||||
request_client_cert: bool,
|
||||
extra_chain_certs: Iterable[certs.Cert],
|
||||
dhparams,
|
||||
dhparams: certs.DHParams,
|
||||
) -> SSL.Context:
|
||||
context: SSL.Context = _create_ssl_context(
|
||||
method=Method.TLS_SERVER_METHOD,
|
||||
@ -202,12 +204,13 @@ def create_client_proxy_context(
|
||||
cipher_list=cipher_list,
|
||||
)
|
||||
|
||||
context.use_certificate(cert.x509)
|
||||
context.use_privatekey(key)
|
||||
try:
|
||||
context.load_verify_locations(chain_file, None)
|
||||
except SSL.Error as e:
|
||||
raise RuntimeError(f"Cannot load certificate chain ({chain_file}).") from e
|
||||
context.use_certificate(cert.to_pyopenssl())
|
||||
context.use_privatekey(crypto.PKey.from_cryptography_key(key))
|
||||
if chain_file is not None:
|
||||
try:
|
||||
context.load_verify_locations(str(chain_file), None)
|
||||
except SSL.Error as e:
|
||||
raise RuntimeError(f"Cannot load certificate chain ({chain_file}).") from e
|
||||
|
||||
if alpn_select_callback is not None:
|
||||
assert callable(alpn_select_callback)
|
||||
@ -227,7 +230,7 @@ def create_client_proxy_context(
|
||||
context.set_verify(Verify.VERIFY_NONE.value, None)
|
||||
|
||||
for i in extra_chain_certs:
|
||||
context.add_extra_chain_cert(i.x509)
|
||||
context.add_extra_chain_cert(i._cert)
|
||||
|
||||
if dhparams:
|
||||
SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)
|
||||
@ -277,7 +280,7 @@ class ClientHello:
|
||||
return self._client_hello.cipher_suites.cipher_suites
|
||||
|
||||
@property
|
||||
def sni(self) -> Optional[bytes]:
|
||||
def sni(self) -> Optional[str]:
|
||||
if self._client_hello.extensions:
|
||||
for extension in self._client_hello.extensions.extensions:
|
||||
is_valid_sni_extension = (
|
||||
@ -287,11 +290,11 @@ class ClientHello:
|
||||
check.is_valid_host(extension.body.server_names[0].host_name)
|
||||
)
|
||||
if is_valid_sni_extension:
|
||||
return extension.body.server_names[0].host_name
|
||||
return extension.body.server_names[0].host_name.decode("ascii")
|
||||
return None
|
||||
|
||||
@property
|
||||
def alpn_protocols(self):
|
||||
def alpn_protocols(self) -> List[bytes]:
|
||||
if self._client_hello.extensions:
|
||||
for extension in self._client_hello.extensions.extensions:
|
||||
if extension.type == 0x10:
|
||||
|
@ -65,7 +65,7 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
|
||||
cipher_list: Sequence[str] = ()
|
||||
"""Ciphers accepted by the proxy server on this connection."""
|
||||
tls_version: Optional[str] = None
|
||||
sni: Union[bytes, Literal[True], None]
|
||||
sni: Union[str, Literal[True], None]
|
||||
|
||||
timestamp_end: Optional[float] = None
|
||||
"""Connection end timestamp"""
|
||||
@ -109,7 +109,8 @@ class Client(Connection):
|
||||
timestamp_start: float
|
||||
"""TCP SYN received"""
|
||||
|
||||
sni: Union[bytes, None] = None
|
||||
mitmcert: Optional[certs.Cert] = None
|
||||
sni: Union[str, None] = None
|
||||
|
||||
def __init__(self, peername, sockname, timestamp_start):
|
||||
self.id = str(uuid.uuid4())
|
||||
@ -122,10 +123,10 @@ class Client(Connection):
|
||||
# This means we need to add all new fields to the old implementation.
|
||||
return {
|
||||
'address': self.peername,
|
||||
'alpn_proto_negotiated': self.alpn,
|
||||
'alpn': self.alpn,
|
||||
'cipher_name': self.cipher,
|
||||
'id': self.id,
|
||||
'mitmcert': None,
|
||||
'mitmcert': self.mitmcert.get_state() if self.mitmcert is not None else None,
|
||||
'sni': self.sni,
|
||||
'timestamp_end': self.timestamp_end,
|
||||
'timestamp_start': self.timestamp_start,
|
||||
@ -155,7 +156,7 @@ class Client(Connection):
|
||||
|
||||
def set_state(self, state):
|
||||
self.peername = tuple(state["address"]) if state["address"] else None
|
||||
self.alpn = state["alpn_proto_negotiated"]
|
||||
self.alpn = state["alpn"]
|
||||
self.cipher = state["cipher_name"]
|
||||
self.id = state["id"]
|
||||
self.sni = state["sni"]
|
||||
@ -169,6 +170,7 @@ class Client(Connection):
|
||||
self.error = state["error"]
|
||||
self.tls = state["tls"]
|
||||
self.certificate_list = [certs.Cert.from_state(x) for x in state["certificate_list"]]
|
||||
self.mitmcert = certs.Cert.from_state(state["mitmcert"]) if state["mitmcert"] is not None else None
|
||||
self.alpn_offers = state["alpn_offers"]
|
||||
self.cipher_list = state["cipher_list"]
|
||||
|
||||
@ -215,7 +217,7 @@ class Server(Connection):
|
||||
timestamp_tcp_setup: Optional[float] = None
|
||||
"""TCP ACK received"""
|
||||
|
||||
sni: Union[bytes, Literal[True], None] = True
|
||||
sni: Union[str, Literal[True], None] = True
|
||||
"""True: client SNI, False: no SNI, bytes: custom value"""
|
||||
via: Optional[server_spec.ServerSpec] = None
|
||||
|
||||
@ -226,7 +228,7 @@ class Server(Connection):
|
||||
def get_state(self):
|
||||
return {
|
||||
'address': self.address,
|
||||
'alpn_proto_negotiated': self.alpn,
|
||||
'alpn': self.alpn,
|
||||
'id': self.id,
|
||||
'ip_address': self.peername,
|
||||
'sni': self.sni,
|
||||
@ -257,7 +259,7 @@ class Server(Connection):
|
||||
|
||||
def set_state(self, state):
|
||||
self.address = tuple(state["address"]) if state["address"] else None
|
||||
self.alpn = state["alpn_proto_negotiated"]
|
||||
self.alpn = state["alpn"]
|
||||
self.id = state["id"]
|
||||
self.peername = tuple(state["ip_address"]) if state["ip_address"] else None
|
||||
self.sni = state["sni"]
|
||||
|
@ -435,7 +435,7 @@ class HttpStream(layer.Layer):
|
||||
|
||||
stack = tunnel.LayerStack()
|
||||
if self.context.server.via.scheme == "https":
|
||||
http_proxy.sni = self.context.server.via.address[0].encode()
|
||||
http_proxy.sni = self.context.server.via.address[0]
|
||||
stack /= tls.ServerTLSLayer(self.context, http_proxy)
|
||||
stack /= _upstream_proxy.HttpUpstreamProxy(self.context, http_proxy, True)
|
||||
|
||||
@ -635,7 +635,7 @@ class HttpLayer(layer.Layer):
|
||||
|
||||
context.server = Server(event.address)
|
||||
if event.tls:
|
||||
context.server.sni = event.address[0].encode()
|
||||
context.server.sni = event.address[0]
|
||||
|
||||
if event.via:
|
||||
assert event.via.scheme in ("http", "https")
|
||||
@ -643,7 +643,7 @@ class HttpLayer(layer.Layer):
|
||||
|
||||
if event.via.scheme == "https":
|
||||
http_proxy.alpn_offers = tls.HTTP_ALPNS
|
||||
http_proxy.sni = event.via.address[0].encode()
|
||||
http_proxy.sni = event.via.address[0]
|
||||
stack /= tls.ServerTLSLayer(context, http_proxy)
|
||||
|
||||
send_connect = not (self.mode == HTTPMode.upstream and not event.tls)
|
||||
|
@ -42,7 +42,7 @@ class ReverseProxy(DestinationKnown):
|
||||
|
||||
if spec.scheme not in ("http", "tcp"):
|
||||
if not self.context.options.keep_host_header:
|
||||
self.context.server.sni = spec.address[0].encode()
|
||||
self.context.server.sni = spec.address[0]
|
||||
self.child_layer = tls.ServerTLSLayer(self.context)
|
||||
else:
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
@ -4,7 +4,6 @@ from dataclasses import dataclass
|
||||
from typing import Iterator, Optional, Tuple
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy.net import tls as net_tls
|
||||
from mitmproxy.proxy import commands, events, layer, tunnel
|
||||
@ -184,6 +183,8 @@ class _TLSLayer(tunnel.TunnelLayer):
|
||||
err = f"OpenSSL {e!r}"
|
||||
return False, err
|
||||
else:
|
||||
# Here we set all attributes that are only known *after* the handshake.
|
||||
|
||||
# Get all peer certificates.
|
||||
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
|
||||
# If called on the client side, the stack also contains the peer's certificate; if called on the server
|
||||
@ -195,9 +196,8 @@ class _TLSLayer(tunnel.TunnelLayer):
|
||||
all_certs.insert(0, cert)
|
||||
|
||||
self.conn.timestamp_tls_setup = time.time()
|
||||
self.conn.sni = self.tls.get_servername()
|
||||
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
|
||||
self.conn.certificate_list = [certs.Cert(x) for x in all_certs]
|
||||
self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs]
|
||||
self.conn.cipher = self.tls.get_cipher_name()
|
||||
self.conn.tls_version = self.tls.get_protocol_version_name()
|
||||
if self.debug:
|
||||
@ -338,8 +338,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
if self.conn.sni:
|
||||
assert isinstance(self.conn.sni, bytes)
|
||||
dest = self.conn.sni.decode("idna")
|
||||
dest = self.conn.sni
|
||||
else:
|
||||
dest = human.format_address(self.context.server.address)
|
||||
if err.startswith("Cannot parse ClientHello"):
|
||||
|
@ -427,7 +427,8 @@ if __name__ == "__main__": # pragma: no cover
|
||||
tls_start.ssl_conn.set_accept_state()
|
||||
else:
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni)
|
||||
if tls_start.context.client.sni is not None:
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni.encode())
|
||||
|
||||
await SimpleConnectionHandler(reader, writer, opts, {
|
||||
"next_layer": next_layer,
|
||||
|
@ -158,7 +158,7 @@ def tclient_conn() -> context.Client:
|
||||
timestamp_end=946681206,
|
||||
sni="address",
|
||||
cipher_name="cipher",
|
||||
alpn_proto_negotiated=b"http/1.1",
|
||||
alpn=b"http/1.1",
|
||||
tls_version="TLSv1.2",
|
||||
tls_extensions=[(0x00, bytes.fromhex("000e00000b6578616d"))],
|
||||
state=0,
|
||||
@ -185,7 +185,7 @@ def tserver_conn() -> context.Server:
|
||||
timestamp_end=946681205,
|
||||
tls_established=True,
|
||||
sni="address",
|
||||
alpn_proto_negotiated=None,
|
||||
alpn=None,
|
||||
tls_version="TLSv1.2",
|
||||
via=None,
|
||||
state=0,
|
||||
|
@ -4,8 +4,7 @@ import urwid
|
||||
import mitmproxy.flow
|
||||
from mitmproxy import http
|
||||
from mitmproxy.tools.console import common, searchable
|
||||
from mitmproxy.utils import human
|
||||
from mitmproxy.utils import strutils
|
||||
from mitmproxy.utils import human, strutils
|
||||
|
||||
|
||||
def maybe_timestamp(base, attr):
|
||||
@ -40,64 +39,37 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
|
||||
text.append(urwid.Text([("head", "Metadata:")]))
|
||||
text.extend(common.format_keyvals(parts, indent=4))
|
||||
|
||||
if sc is not None and sc.ip_address:
|
||||
if sc is not None and sc.peername:
|
||||
text.append(urwid.Text([("head", "Server Connection:")]))
|
||||
parts = [
|
||||
("Address", human.format_address(sc.address)),
|
||||
]
|
||||
if sc.ip_address:
|
||||
parts.append(("Resolved Address", human.format_address(sc.ip_address)))
|
||||
if sc.peername:
|
||||
parts.append(("Resolved Address", human.format_address(sc.peername)))
|
||||
if resp:
|
||||
parts.append(("HTTP Version", resp.http_version))
|
||||
if sc.alpn_proto_negotiated:
|
||||
parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn_proto_negotiated)))
|
||||
if sc.alpn:
|
||||
parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn)))
|
||||
|
||||
text.extend(
|
||||
common.format_keyvals(parts, indent=4)
|
||||
)
|
||||
|
||||
c = sc.cert
|
||||
c = sc.certificate_list[0]
|
||||
if c:
|
||||
text.append(urwid.Text([("head", "Server Certificate:")]))
|
||||
parts = [
|
||||
("Type", "%s, %s bits" % c.keyinfo),
|
||||
("SHA1 digest", c.digest("sha1")),
|
||||
("SHA256 digest", c.fingerprint().hex()),
|
||||
("Valid to", str(c.notafter)),
|
||||
("Valid from", str(c.notbefore)),
|
||||
("Serial", str(c.serial)),
|
||||
(
|
||||
"Subject",
|
||||
urwid.BoxAdapter(
|
||||
urwid.ListBox(
|
||||
common.format_keyvals(
|
||||
c.subject,
|
||||
key_format="highlight"
|
||||
)
|
||||
),
|
||||
len(c.subject)
|
||||
)
|
||||
),
|
||||
(
|
||||
"Issuer",
|
||||
urwid.BoxAdapter(
|
||||
urwid.ListBox(
|
||||
common.format_keyvals(
|
||||
c.issuer,
|
||||
key_format="highlight"
|
||||
)
|
||||
),
|
||||
len(c.issuer)
|
||||
)
|
||||
)
|
||||
("Subject", urwid.Pile(common.format_keyvals(c.subject, key_format="highlight"))),
|
||||
("Issuer", urwid.Pile(common.format_keyvals(c.issuer, key_format="highlight")))
|
||||
]
|
||||
|
||||
if c.altnames:
|
||||
parts.append(
|
||||
(
|
||||
"Alt names",
|
||||
", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames)
|
||||
)
|
||||
)
|
||||
parts.append(("Alt names", ", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames)))
|
||||
text.extend(
|
||||
common.format_keyvals(parts, indent=4)
|
||||
)
|
||||
@ -106,19 +78,18 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
|
||||
text.append(urwid.Text([("head", "Client Connection:")]))
|
||||
|
||||
parts = [
|
||||
("Address", "{}:{}".format(cc.address[0], cc.address[1])),
|
||||
("Address", human.format_address(cc.peername)),
|
||||
]
|
||||
if req:
|
||||
parts.append(("HTTP Version", req.http_version))
|
||||
if cc.tls_version:
|
||||
parts.append(("TLS Version", cc.tls_version))
|
||||
if cc.sni:
|
||||
parts.append(("Server Name Indication",
|
||||
strutils.bytes_to_escaped_str(strutils.always_bytes(cc.sni, "idna"))))
|
||||
if cc.cipher_name:
|
||||
parts.append(("Cipher Name", cc.cipher_name))
|
||||
if cc.alpn_proto_negotiated:
|
||||
parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn_proto_negotiated)))
|
||||
parts.append(("Server Name Indication", cc.sni))
|
||||
if cc.cipher:
|
||||
parts.append(("Cipher Name", cc.cipher))
|
||||
if cc.alpn:
|
||||
parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn)))
|
||||
|
||||
text.extend(
|
||||
common.format_keyvals(parts, indent=4)
|
||||
|
@ -47,8 +47,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
||||
"timestamp_start": flow.client_conn.timestamp_start,
|
||||
"timestamp_tls_setup": flow.client_conn.timestamp_tls_setup,
|
||||
"timestamp_end": flow.client_conn.timestamp_end,
|
||||
# ideally idna, but we don't want errors
|
||||
"sni": always_str(flow.client_conn.sni, "ascii", "backslashreplace"),
|
||||
"sni": flow.client_conn.sni,
|
||||
"cipher_name": flow.client_conn.cipher,
|
||||
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
|
||||
"tls_version": flow.client_conn.tls_version,
|
||||
@ -61,18 +60,14 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
||||
"ip_address": flow.server_conn.peername,
|
||||
"source_address": flow.server_conn.sockname,
|
||||
"tls_established": flow.server_conn.tls_established,
|
||||
"alpn_proto_negotiated": always_str(flow.server_conn.alpn, "ascii", "backslashreplace"),
|
||||
"sni": flow.server_conn.sni,
|
||||
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
|
||||
"tls_version": flow.server_conn.tls_version,
|
||||
"timestamp_start": flow.server_conn.timestamp_start,
|
||||
"timestamp_tcp_setup": flow.server_conn.timestamp_tcp_setup,
|
||||
"timestamp_tls_setup": flow.server_conn.timestamp_tls_setup,
|
||||
"timestamp_end": flow.server_conn.timestamp_end,
|
||||
}
|
||||
if flow.server_conn.sni is True:
|
||||
f["server_conn"] = None
|
||||
else:
|
||||
# ideally idna, but we don't want errors
|
||||
f["server_conn"] = always_str(flow.server_conn.sni, "ascii", "backslashreplace")
|
||||
if flow.error:
|
||||
f["error"] = flow.error.get_state()
|
||||
|
||||
|
@ -7,7 +7,7 @@ MITMPROXY = "mitmproxy " + VERSION
|
||||
|
||||
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
|
||||
# for each change in the file format.
|
||||
FLOW_FORMAT_VERSION = 10
|
||||
FLOW_FORMAT_VERSION = 11
|
||||
|
||||
|
||||
def get_dev_version() -> str:
|
||||
|
1
setup.py
1
setup.py
@ -80,7 +80,6 @@ setup(
|
||||
"msgpack>=1.0.0, <1.1.0",
|
||||
"passlib>=1.6.5, <1.8",
|
||||
"protobuf>=3.14,<3.15",
|
||||
"pyasn1>=0.3.1,<0.5",
|
||||
"pyOpenSSL>=20.0,<20.1",
|
||||
"pyparsing>=2.4.2,<2.5",
|
||||
"pyperclip>=1.6.0,<1.9",
|
||||
|
21
test/helper_tools/memoryleak2.py
Normal file
21
test/helper_tools/memoryleak2.py
Normal file
@ -0,0 +1,21 @@
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
|
||||
import objgraph
|
||||
|
||||
from mitmproxy import certs
|
||||
|
||||
if __name__ == "__main__":
|
||||
store = certs.CertStore.from_store(path=Path("~/.mitmproxy/").expanduser(), basename="mitmproxy", key_size=2048)
|
||||
store.STORE_CAP = 5
|
||||
|
||||
for _ in range(5):
|
||||
store.get_cert(commonname=secrets.token_hex(16).encode(), sans=[], organization=None)
|
||||
|
||||
objgraph.show_growth()
|
||||
|
||||
for _ in range(20):
|
||||
store.get_cert(commonname=secrets.token_hex(16).encode(), sans=[], organization=None)
|
||||
|
||||
print("====")
|
||||
objgraph.show_growth()
|
@ -57,21 +57,21 @@ class TestTlsConfig:
|
||||
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
|
||||
|
||||
# Edge case first: We don't have _any_ idea about the server, so we just return "mitmproxy" as subject.
|
||||
cert, pkey, chainfile = ta.get_cert(ctx)
|
||||
assert cert.cn == b"mitmproxy"
|
||||
entry = ta.get_cert(ctx)
|
||||
assert entry.cert.cn == "mitmproxy"
|
||||
|
||||
# Here we have an existing server connection...
|
||||
ctx.server.address = ("server-address.example", 443)
|
||||
with open(tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.crt"), "rb") as f:
|
||||
ctx.server.certificate_list = [certs.Cert.from_pem(f.read())]
|
||||
cert, pkey, chainfile = ta.get_cert(ctx)
|
||||
assert cert.cn == b"example.mitmproxy.org"
|
||||
assert cert.altnames == [b"example.mitmproxy.org", b"server-address.example"]
|
||||
entry = ta.get_cert(ctx)
|
||||
assert entry.cert.cn == "example.mitmproxy.org"
|
||||
assert entry.cert.altnames == ["example.mitmproxy.org", "server-address.example"]
|
||||
|
||||
# And now we also incorporate SNI.
|
||||
ctx.client.sni = b"sni.example"
|
||||
cert, pkey, chainfile = ta.get_cert(ctx)
|
||||
assert cert.altnames == [b"example.mitmproxy.org", b"sni.example"]
|
||||
ctx.client.sni = "sni.example"
|
||||
entry = ta.get_cert(ctx)
|
||||
assert entry.cert.altnames == ["example.mitmproxy.org", "sni.example"]
|
||||
|
||||
def test_tls_clienthello(self):
|
||||
# only really testing for coverage here, there's no point in mirroring the individual conditions
|
||||
@ -222,7 +222,7 @@ class TestTlsConfig:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ca_expired(self, monkeypatch):
|
||||
monkeypatch.setattr(SSL.X509, "has_expired", lambda self: True)
|
||||
monkeypatch.setattr(certs.Cert, "has_expired", lambda self: True)
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
ta.configure(["confdir"])
|
||||
|
14
test/mitmproxy/data/confdir/mitmproxy-dhparam.pem
Normal file
14
test/mitmproxy/data/confdir/mitmproxy-dhparam.pem
Normal file
@ -0,0 +1,14 @@
|
||||
|
||||
-----BEGIN DH PARAMETERS-----
|
||||
MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3
|
||||
O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv
|
||||
j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ
|
||||
Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB
|
||||
chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC
|
||||
ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq
|
||||
o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX
|
||||
IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv
|
||||
A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8
|
||||
6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I
|
||||
rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
|
||||
-----END DH PARAMETERS-----
|
19
test/mitmproxy/net/data/dsa_cert.pem
Normal file
19
test/mitmproxy/net/data/dsa_cert.pem
Normal file
@ -0,0 +1,19 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDHTCCAtugAwIBAgIJAJNwd38WrLI/MAsGCWCGSAFlAwQDAjBFMQswCQYDVQQG
|
||||
EwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lk
|
||||
Z2l0cyBQdHkgTHRkMB4XDTIwMTIzMDE5MzkxMloXDTIxMDEyOTE5MzkxMlowRTEL
|
||||
MAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVy
|
||||
bmV0IFdpZGdpdHMgUHR5IEx0ZDCCAbcwggEsBgcqhkjOOAQBMIIBHwKBgQDL8YvJ
|
||||
wqiJEdWKh7KBbgyCxpWhZdxyLFMguXwI2iX3+n8m2vg/6kBLK/sbI3y3qTeQOvQ0
|
||||
V/WchWWAnGHT9CwB1MNhqMWd409ZAo49/m6IPYvoB1RctKyUE9E8jUad+jM2WGm9
|
||||
zCG+KPeYFebTsWiQhhIgx8vPM657x6wA5y+omQIVAP4uMtD8Xv+TZCeXTB4j2bi0
|
||||
yxjTAoGBAIRxOxIma4B4sEToblQjbwQh9gnWEhvvwAefu3Gcav12mRgYBVqPeskJ
|
||||
0zROLFr9ubGZo0F/xBGCpaVjY+nDbE8/MwURLBfBr2wQTDbXL/Z5Ea1wNIURsqdo
|
||||
VxXKADqPh/OfylObCONxF37RnSiABwlCPpIERj+g90oNSFLnvftiA4GEAAKBgC6X
|
||||
5sBsCV71iCPZSM1UCiSOnzCg/6w5nHPawwcdPUtU6/mpOdylfHlttr4jA1c2a3TB
|
||||
9ppxe1H0945nYFCHmU5zGtgwKPlNMS5GLrHtSbxiwB5dUm5AsNsLA0EVLF1Po7sI
|
||||
xQRWTWp57iLzRGdDoKFiPaW07g/UnNzd6i/RAfcWo1MwUTAdBgNVHQ4EFgQUtX4q
|
||||
D2f00P8e4zp9L+zOib+Fv10wHwYDVR0jBBgwFoAUtX4qD2f00P8e4zp9L+zOib+F
|
||||
v10wDwYDVR0TAQH/BAUwAwEB/zALBglghkgBZQMEAwIDLwAwLAIURQwecWi0wMjS
|
||||
EsmC5hxASeuH2UkCFG6zZfM0Sbms27vcrCfVEtxzhizU
|
||||
-----END CERTIFICATE-----
|
12
test/mitmproxy/net/data/ec_cert.pem
Normal file
12
test/mitmproxy/net/data/ec_cert.pem
Normal file
@ -0,0 +1,12 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIB1DCCAXqgAwIBAgIJALAxolM7r60uMAoGCCqGSM49BAMCMEUxCzAJBgNVBAYT
|
||||
AkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn
|
||||
aXRzIFB0eSBMdGQwHhcNMjAxMjMwMTkzOTIyWhcNMjEwMTI5MTkzOTIyWjBFMQsw
|
||||
CQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJu
|
||||
ZXQgV2lkZ2l0cyBQdHkgTHRkMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbfF4
|
||||
pQnC+cJvoOyp8g9scQsJLNBsd4Vd0ADCyLBx84B1WucmBfpqi+8ERo8e8Y899UKv
|
||||
KWov2eqOUmGiCLS3kaNTMFEwHQYDVR0OBBYEFHRF3WQJ6x05/c29Rv0mPZn6CwKb
|
||||
MB8GA1UdIwQYMBaAFHRF3WQJ6x05/c29Rv0mPZn6CwKbMA8GA1UdEwEB/wQFMAMB
|
||||
Af8wCgYIKoZIzj0EAwIDSAAwRQIgVDA9XrBWlQQteO0LHlkXVpZH0QXjtnQi426Y
|
||||
ZxGgG18CIQCVSh3iWEAsCXWWUwiVmqEbpWGNIwwpbK4wL7+5lvUvXw==
|
||||
-----END CERTIFICATE-----
|
@ -29,7 +29,7 @@ def test_sslkeylogfile(tdata, monkeypatch):
|
||||
Path(tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.pem")),
|
||||
Path(tdata.path("mitmproxy/net/data/dhparam.pem"))
|
||||
)
|
||||
cert, key, chain_file = store.get_cert(b"example.com", [], None)
|
||||
entry = store.get_cert("example.com", [], None)
|
||||
|
||||
cctx = tls.create_proxy_server_context(
|
||||
min_version=tls.DEFAULT_MIN_VERSION,
|
||||
@ -46,9 +46,9 @@ def test_sslkeylogfile(tdata, monkeypatch):
|
||||
min_version=tls.DEFAULT_MIN_VERSION,
|
||||
max_version=tls.DEFAULT_MAX_VERSION,
|
||||
cipher_list=None,
|
||||
cert=cert,
|
||||
key=key,
|
||||
chain_file=chain_file,
|
||||
cert=entry.cert,
|
||||
key=entry.privatekey,
|
||||
chain_file=entry.chain_file,
|
||||
alpn_select_callback=None,
|
||||
request_client_cert=False,
|
||||
extra_chain_certs=(),
|
||||
@ -105,7 +105,7 @@ class TestClientHello:
|
||||
)
|
||||
c = tls.ClientHello(data)
|
||||
assert repr(c)
|
||||
assert c.sni == b'example.com'
|
||||
assert c.sni == 'example.com'
|
||||
assert c.cipher_suites == [
|
||||
49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161,
|
||||
49171, 49162, 49172, 156, 157, 47, 53, 10
|
||||
|
@ -202,7 +202,7 @@ def test_reverse_proxy_tcp_over_tls(tctx: Context, monkeypatch, patch, connectio
|
||||
>> reply_tls_start()
|
||||
<< SendData(tctx.server, data)
|
||||
)
|
||||
assert tls.parse_client_hello(data()).sni == b"localhost"
|
||||
assert tls.parse_client_hello(data()).sni == "localhost"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("connection_strategy", ["eager", "lazy"])
|
||||
|
@ -76,7 +76,7 @@ def test_get_client_hello():
|
||||
|
||||
|
||||
def test_parse_client_hello():
|
||||
assert tls.parse_client_hello(client_hello_with_extensions).sni == b"example.com"
|
||||
assert tls.parse_client_hello(client_hello_with_extensions).sni == "example.com"
|
||||
assert tls.parse_client_hello(client_hello_with_extensions[:50]) is None
|
||||
with pytest.raises(ValueError):
|
||||
tls.parse_client_hello(client_hello_with_extensions[:183] + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00')
|
||||
@ -188,7 +188,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
# Set SNI
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni)
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode())
|
||||
|
||||
# Manually enable hostname verification.
|
||||
# Recent OpenSSL versions provide slightly nicer ways to do this, but they are not exposed in
|
||||
@ -202,7 +202,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
||||
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
|
||||
)
|
||||
SSL._openssl_assert(
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni, 0) == 1
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode(), 0) == 1
|
||||
)
|
||||
|
||||
return tutils.reply(*args, side_effect=make_conn, **kwargs)
|
||||
@ -227,7 +227,7 @@ class TestServerTLS:
|
||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||
tctx.server.state = ConnectionState.OPEN
|
||||
tctx.server.address = ("example.mitmproxy.org", 443)
|
||||
tctx.server.sni = b"example.mitmproxy.org"
|
||||
tctx.server.sni = "example.mitmproxy.org"
|
||||
|
||||
tssl = SSLTest(server_side=True)
|
||||
|
||||
@ -280,7 +280,7 @@ class TestServerTLS:
|
||||
"""If the certificate is not trusted, we should fail."""
|
||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||
tctx.server.address = ("wrong.host.mitmproxy.org", 443)
|
||||
tctx.server.sni = b"wrong.host.mitmproxy.org"
|
||||
tctx.server.sni = "wrong.host.mitmproxy.org"
|
||||
|
||||
tssl = SSLTest(server_side=True)
|
||||
|
||||
@ -316,7 +316,7 @@ class TestServerTLS:
|
||||
def test_remote_speaks_no_tls(self, tctx):
|
||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||
tctx.server.state = ConnectionState.OPEN
|
||||
tctx.server.sni = b"example.mitmproxy.org"
|
||||
tctx.server.sni = "example.mitmproxy.org"
|
||||
|
||||
# send ClientHello, receive random garbage back
|
||||
data = tutils.Placeholder(bytes)
|
||||
@ -345,7 +345,7 @@ def make_client_tls_layer(
|
||||
|
||||
# Add some server config, this is needed anyways.
|
||||
tctx.server.address = ("example.mitmproxy.org", 443)
|
||||
tctx.server.sni = b"example.mitmproxy.org"
|
||||
tctx.server.sni = "example.mitmproxy.org"
|
||||
|
||||
tssl_client = SSLTest(**kwargs)
|
||||
# Start handshake.
|
||||
|
@ -1,7 +1,12 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from mitmproxy import certs
|
||||
from ..conftest import skip_windows
|
||||
|
||||
|
||||
# class TestDNTree:
|
||||
# def test_simple(self):
|
||||
# d = certs.DNTree()
|
||||
@ -32,85 +37,87 @@ from ..conftest import skip_windows
|
||||
# assert d.get("com") == "foo"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tstore(tdata):
|
||||
return certs.CertStore.from_store(tdata.path("mitmproxy/data/confdir"), "mitmproxy", 2048)
|
||||
|
||||
|
||||
class TestCertStore:
|
||||
|
||||
def test_create_explicit(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
assert ca.get_cert(b"foo", [])
|
||||
assert ca.get_cert("foo", [])
|
||||
|
||||
ca2 = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
assert ca2.get_cert(b"foo", [])
|
||||
assert ca2.get_cert("foo", [])
|
||||
|
||||
assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
|
||||
assert ca.default_ca.serial == ca2.default_ca.serial
|
||||
|
||||
def test_create_no_common_name(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
assert ca.get_cert(None, [])[0].cn is None
|
||||
def test_create_no_common_name(self, tstore):
|
||||
assert tstore.get_cert(None, []).cert.cn is None
|
||||
|
||||
def test_create_tmp(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
assert ca.get_cert(b"foo.com", [])
|
||||
assert ca.get_cert(b"foo.com", [])
|
||||
assert ca.get_cert(b"*.foo.com", [])
|
||||
def test_chain_file(self, tdata, tmp_path):
|
||||
cert = Path(tdata.path("mitmproxy/data/confdir/mitmproxy-ca.pem")).read_bytes()
|
||||
(tmp_path / "mitmproxy-ca.pem").write_bytes(cert)
|
||||
ca = certs.CertStore.from_store(tmp_path, "mitmproxy", 2048)
|
||||
assert ca.default_chain_file is None
|
||||
|
||||
r = ca.get_cert(b"*.foo.com", [])
|
||||
assert r[1] == ca.default_privatekey
|
||||
(tmp_path / "mitmproxy-ca.pem").write_bytes(2 * cert)
|
||||
ca = certs.CertStore.from_store(tmp_path, "mitmproxy", 2048)
|
||||
assert ca.default_chain_file == (tmp_path / "mitmproxy-ca.pem")
|
||||
|
||||
def test_sans(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
c1 = ca.get_cert(b"foo.com", [b"*.bar.com"])
|
||||
ca.get_cert(b"foo.bar.com", [])
|
||||
def test_sans(self, tstore):
|
||||
c1 = tstore.get_cert("foo.com", ["*.bar.com"])
|
||||
tstore.get_cert("foo.bar.com", [])
|
||||
# assert c1 == c2
|
||||
c3 = ca.get_cert(b"bar.com", [])
|
||||
c3 = tstore.get_cert("bar.com", [])
|
||||
assert not c1 == c3
|
||||
|
||||
def test_sans_change(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
ca.get_cert(b"foo.com", [b"*.bar.com"])
|
||||
cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"])
|
||||
assert b"*.baz.com" in cert.altnames
|
||||
def test_sans_change(self, tstore):
|
||||
tstore.get_cert("foo.com", ["*.bar.com"])
|
||||
entry = tstore.get_cert("foo.bar.com", ["*.baz.com"])
|
||||
assert "*.baz.com" in entry.cert.altnames
|
||||
|
||||
def test_expire(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
ca.STORE_CAP = 3
|
||||
ca.get_cert(b"one.com", [])
|
||||
ca.get_cert(b"two.com", [])
|
||||
ca.get_cert(b"three.com", [])
|
||||
def test_expire(self, tstore):
|
||||
tstore.STORE_CAP = 3
|
||||
tstore.get_cert("one.com", [])
|
||||
tstore.get_cert("two.com", [])
|
||||
tstore.get_cert("three.com", [])
|
||||
|
||||
assert (b"one.com", ()) in ca.certs
|
||||
assert (b"two.com", ()) in ca.certs
|
||||
assert (b"three.com", ()) in ca.certs
|
||||
assert ("one.com", ()) in tstore.certs
|
||||
assert ("two.com", ()) in tstore.certs
|
||||
assert ("three.com", ()) in tstore.certs
|
||||
|
||||
ca.get_cert(b"one.com", [])
|
||||
tstore.get_cert("one.com", [])
|
||||
|
||||
assert (b"one.com", ()) in ca.certs
|
||||
assert (b"two.com", ()) in ca.certs
|
||||
assert (b"three.com", ()) in ca.certs
|
||||
assert ("one.com", ()) in tstore.certs
|
||||
assert ("two.com", ()) in tstore.certs
|
||||
assert ("three.com", ()) in tstore.certs
|
||||
|
||||
ca.get_cert(b"four.com", [])
|
||||
tstore.get_cert("four.com", [])
|
||||
|
||||
assert (b"one.com", ()) not in ca.certs
|
||||
assert (b"two.com", ()) in ca.certs
|
||||
assert (b"three.com", ()) in ca.certs
|
||||
assert (b"four.com", ()) in ca.certs
|
||||
assert ("one.com", ()) not in tstore.certs
|
||||
assert ("two.com", ()) in tstore.certs
|
||||
assert ("three.com", ()) in tstore.certs
|
||||
assert ("four.com", ()) in tstore.certs
|
||||
|
||||
def test_overrides(self, tmpdir):
|
||||
ca1 = certs.CertStore.from_store(str(tmpdir.join("ca1")), "test", 2048)
|
||||
ca2 = certs.CertStore.from_store(str(tmpdir.join("ca2")), "test", 2048)
|
||||
assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
|
||||
def test_overrides(self, tmp_path):
|
||||
ca1 = certs.CertStore.from_store(tmp_path / "ca1", "test", 2048)
|
||||
ca2 = certs.CertStore.from_store(tmp_path / "ca2", "test", 2048)
|
||||
assert not ca1.default_ca.serial == ca2.default_ca.serial
|
||||
|
||||
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
|
||||
dcp = tmpdir.join("dc")
|
||||
dcp.write(dc[0].to_pem())
|
||||
ca1.add_cert_file("foo.com", str(dcp))
|
||||
dc = ca2.get_cert("foo.com", ["sans.example.com"])
|
||||
dcp = tmp_path / "dc"
|
||||
dcp.write_bytes(dc.cert.to_pem())
|
||||
ca1.add_cert_file("foo.com", dcp)
|
||||
|
||||
ret = ca1.get_cert(b"foo.com", [])
|
||||
assert ret[0].serial == dc[0].serial
|
||||
ret = ca1.get_cert("foo.com", [])
|
||||
assert ret.cert.serial == dc.cert.serial
|
||||
|
||||
def test_create_dhparams(self, tmpdir):
|
||||
filename = str(tmpdir.join("dhparam.pem"))
|
||||
def test_create_dhparams(self, tmp_path):
|
||||
filename = tmp_path / "dhparam.pem"
|
||||
certs.CertStore.load_dhparam(filename)
|
||||
assert os.path.exists(filename)
|
||||
assert filename.exists()
|
||||
|
||||
@skip_windows
|
||||
def test_umask_secret(self, tmpdir):
|
||||
@ -123,22 +130,21 @@ class TestCertStore:
|
||||
|
||||
class TestDummyCert:
|
||||
|
||||
def test_with_ca(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
def test_with_ca(self, tstore):
|
||||
r = certs.dummy_cert(
|
||||
ca.default_privatekey,
|
||||
ca.default_ca,
|
||||
b"foo.com",
|
||||
[b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"],
|
||||
b"Foo Ltd."
|
||||
tstore.default_privatekey,
|
||||
tstore.default_ca._cert,
|
||||
"foo.com",
|
||||
["one.com", "two.com", "*.three.com", "127.0.0.1"],
|
||||
"Foo Ltd."
|
||||
)
|
||||
assert r.cn == b"foo.com"
|
||||
assert r.altnames == [b'one.com', b'two.com', b'*.three.com']
|
||||
assert r.organization == b"Foo Ltd."
|
||||
assert r.cn == "foo.com"
|
||||
assert r.altnames == ["one.com", "two.com", "*.three.com", "127.0.0.1"]
|
||||
assert r.organization == "Foo Ltd."
|
||||
|
||||
r = certs.dummy_cert(
|
||||
ca.default_privatekey,
|
||||
ca.default_ca,
|
||||
tstore.default_privatekey,
|
||||
tstore.default_ca._cert,
|
||||
None,
|
||||
[],
|
||||
None
|
||||
@ -154,16 +160,16 @@ class TestCert:
|
||||
with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
||||
d = f.read()
|
||||
c1 = certs.Cert.from_pem(d)
|
||||
assert c1.cn == b"google.com"
|
||||
assert c1.cn == "google.com"
|
||||
assert len(c1.altnames) == 436
|
||||
assert c1.organization == b"Google Inc"
|
||||
assert c1.organization == "Google Inc"
|
||||
|
||||
with open(tdata.path("mitmproxy/net/data/text_cert_2"), "rb") as f:
|
||||
d = f.read()
|
||||
c2 = certs.Cert.from_pem(d)
|
||||
assert c2.cn == b"www.inode.co.nz"
|
||||
assert c2.cn == "www.inode.co.nz"
|
||||
assert len(c2.altnames) == 2
|
||||
assert c2.digest("sha1")
|
||||
assert c2.fingerprint()
|
||||
assert c2.notbefore
|
||||
assert c2.notafter
|
||||
assert c2.subject
|
||||
@ -171,10 +177,30 @@ class TestCert:
|
||||
assert c2.serial
|
||||
assert c2.issuer
|
||||
assert c2.to_pem()
|
||||
assert c2.has_expired is not None
|
||||
assert c2.has_expired() is not None
|
||||
|
||||
assert c1 != c2
|
||||
|
||||
def test_convert(self, tdata):
|
||||
with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
||||
d = f.read()
|
||||
c = certs.Cert.from_pem(d)
|
||||
|
||||
assert c == certs.Cert.from_pem(c.to_pem())
|
||||
assert c == certs.Cert.from_state(c.get_state())
|
||||
assert c == certs.Cert.from_pyopenssl(c.to_pyopenssl())
|
||||
|
||||
@pytest.mark.parametrize("filename,name,bits", [
|
||||
("text_cert", "RSA", 1024),
|
||||
("dsa_cert.pem", "DSA", 1024),
|
||||
("ec_cert.pem", "EC (secp256r1)", 256),
|
||||
])
|
||||
def test_keyinfo(self, tdata, filename, name, bits):
|
||||
with open(tdata.path(f"mitmproxy/net/data/{filename}"), "rb") as f:
|
||||
d = f.read()
|
||||
c = certs.Cert.from_pem(d)
|
||||
assert c.keyinfo == (name, bits)
|
||||
|
||||
def test_err_broken_sans(self, tdata):
|
||||
with open(tdata.path("mitmproxy/net/data/text_cert_weird1"), "rb") as f:
|
||||
d = f.read()
|
||||
@ -182,12 +208,6 @@ class TestCert:
|
||||
# This breaks unless we ignore a decoding error.
|
||||
assert c.altnames is not None
|
||||
|
||||
def test_der(self, tdata):
|
||||
with open(tdata.path("mitmproxy/net/data/dercert"), "rb") as f:
|
||||
d = f.read()
|
||||
s = certs.Cert.from_der(d)
|
||||
assert s.cn
|
||||
|
||||
def test_state(self, tdata):
|
||||
with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
||||
d = f.read()
|
||||
@ -201,12 +221,13 @@ class TestCert:
|
||||
assert c == c2
|
||||
assert c is not c2
|
||||
|
||||
x = certs.Cert('')
|
||||
x.set_state(a)
|
||||
assert x == c
|
||||
c2.set_state(a)
|
||||
assert c == c2
|
||||
|
||||
def test_from_store_with_passphrase(self, tdata, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, b"password")
|
||||
ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), b"password")
|
||||
def test_from_store_with_passphrase(self, tdata, tstore):
|
||||
tstore.add_cert_file("unencrypted-no-pass", Path(tdata.path("mitmproxy/data/testkey.pem")), None)
|
||||
tstore.add_cert_file("unencrypted-pass", Path(tdata.path("mitmproxy/data/testkey.pem")), b"password")
|
||||
tstore.add_cert_file("encrypted-pass", Path(tdata.path("mitmproxy/data/mitmproxy.pem")), b"password")
|
||||
|
||||
assert ca.get_cert(b"foo", [])
|
||||
with pytest.raises(TypeError):
|
||||
tstore.add_cert_file("encrypted-no-pass", Path(tdata.path("mitmproxy/data/mitmproxy.pem")), None)
|
||||
|
@ -25,12 +25,10 @@ def test_format_keyvals():
|
||||
("ee", None),
|
||||
]
|
||||
)
|
||||
wrapped = urwid.BoxAdapter(
|
||||
urwid.ListBox(
|
||||
urwid.SimpleFocusListWalker(
|
||||
common.format_keyvals([("foo", "bar")])
|
||||
)
|
||||
), 1
|
||||
wrapped = urwid.Pile(
|
||||
urwid.SimpleFocusListWalker(
|
||||
common.format_keyvals([("foo", "bar")])
|
||||
)
|
||||
)
|
||||
assert wrapped.render((30,))
|
||||
assert common.format_keyvals(
|
||||
|
Loading…
Reference in New Issue
Block a user