diff --git a/CHANGELOG.md b/CHANGELOG.md index 08795483f..184c10ab8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ If you depend on these features, please raise your voice in ### Full Changelog * New Proxy Core based on sans-io pattern (@mhils) +* Use pyca/cryptography to generate certificates, not pyOpenSSL (@mhils) * Remove the legacy protocol stack (@Kriechi) * Remove all deprecated pathod and pathoc tools and modules (@Kriechi) * --- TODO: add new PRs above this line --- diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index 64263c9ea..16a025de2 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -66,7 +66,7 @@ class NextLayer: pass else: if sni: - hostnames.append(sni.decode("idna")) + hostnames.append(sni) if not hostnames: return False diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index 110490987..f1dc90c07 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -1,7 +1,8 @@ import os -from typing import List, Optional, Tuple, TypedDict, Any +from pathlib import Path +from typing import List, Optional, TypedDict, Any -from OpenSSL import SSL, crypto +from OpenSSL import SSL from mitmproxy import certs, ctx, exceptions from mitmproxy.net import tls as net_tls from mitmproxy.options import CONF_BASENAME @@ -115,7 +116,7 @@ class TlsConfig: client: context.Client = tls_start.context.client server: context.Server = tls_start.context.server - cert, key, chain_file = self.get_cert(tls_start.context) + entry = self.get_cert(tls_start.context) if not client.cipher_list and ctx.options.ciphers_client: client.cipher_list = ctx.options.ciphers_client.split(":") @@ -126,9 +127,9 @@ class TlsConfig: min_version=net_tls.Version[ctx.options.tls_version_client_min], max_version=net_tls.Version[ctx.options.tls_version_client_max], cipher_list=cipher_list, - cert=cert, - key=key, - chain_file=chain_file, + cert=entry.cert, + key=entry.privatekey, + chain_file=entry.chain_file, request_client_cert=False, alpn_select_callback=alpn_select_callback, extra_chain_certs=server.certificate_list, @@ -152,8 +153,7 @@ class TlsConfig: verify = net_tls.Verify.VERIFY_PEER if server.sni is True: - server.sni = client.sni or server.address[0].encode() - sni = server.sni or None # make sure that false-y values are None + server.sni = client.sni or server.address[0] if not server.alpn_offers: if client.alpn_offers: @@ -182,7 +182,7 @@ class TlsConfig: if os.path.isfile(client_certs): client_cert = client_certs else: - server_name: str = (server.sni or server.address[0].encode("idna")).decode() + server_name: str = server.sni or server.address[0] p = os.path.join(client_certs, f"{server_name}.pem") if os.path.isfile(p): client_cert = p @@ -192,7 +192,7 @@ class TlsConfig: max_version=net_tls.Version[ctx.options.tls_version_client_max], cipher_list=cipher_list, verify=verify, - sni=sni, + sni=server.sni, ca_path=ctx.options.ssl_verify_upstream_trusted_confdir, ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca, client_cert=client_cert, @@ -200,7 +200,8 @@ class TlsConfig: ) tls_start.ssl_conn = SSL.Connection(ssl_ctx) - tls_start.ssl_conn.set_tlsext_host_name(server.sni) + if server.sni: + tls_start.ssl_conn.set_tlsext_host_name(server.sni.encode()) tls_start.ssl_conn.set_connect_state() def running(self): @@ -232,8 +233,8 @@ class TlsConfig: if len(parts) == 1: parts = ["*", parts[0]] - cert = os.path.expanduser(parts[1]) - if not os.path.exists(cert): + cert = Path(parts[1]).expanduser() + if not cert.exists(): raise exceptions.OptionsError(f"Certificate file does not exist: {cert}") try: self.certstore.add_cert_file( @@ -241,16 +242,16 @@ class TlsConfig: cert, passphrase=ctx.options.cert_passphrase.encode("utf8") if ctx.options.cert_passphrase else None, ) - except crypto.Error as e: - raise exceptions.OptionsError(f"Invalid certificate format: {cert}") from e + except ValueError as e: + raise exceptions.OptionsError(f"Invalid certificate format for {cert}: {e}") from e - def get_cert(self, conn_context: context.Context) -> Tuple[certs.Cert, SSL.PKey, str]: + def get_cert(self, conn_context: context.Context) -> certs.CertStoreEntry: """ This function determines the Common Name (CN), Subject Alternative Names (SANs) and Organization Name our certificate should have and then fetches a matching cert from the certstore. """ - altnames: List[bytes] = [] - organization: Optional[bytes] = None + altnames: List[str] = [] + organization: Optional[str] = None # Use upstream certificate if available. if conn_context.server.certificate_list: @@ -265,11 +266,11 @@ class TlsConfig: if conn_context.client.sni: altnames.append(conn_context.client.sni) elif conn_context.server.address: - altnames.append(conn_context.server.address[0].encode("idna")) + altnames.append(conn_context.server.address[0]) # As a last resort, add *something* so that we have a certificate to serve. if not altnames: - altnames.append(b"mitmproxy") + altnames.append("mitmproxy") # only keep first occurrence of each hostname altnames = list(dict.fromkeys(altnames)) diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index ece050789..e5f121bde 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -1,23 +1,24 @@ -import os -import ssl -import time +import contextlib import datetime import ipaddress +import os import sys -import contextlib +from dataclasses import dataclass from pathlib import Path -from typing import Tuple, Optional, Union, Dict, List +from typing import Tuple, Optional, Union, Dict, List, NewType + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec +from cryptography.hazmat.primitives.serialization import pkcs12 +from cryptography.x509 import NameOID, ExtendedKeyUsageOID -from pyasn1.type import univ, constraint, char, namedtype, tag -from pyasn1.codec.der.decoder import decode -from pyasn1.error import PyAsn1Error import OpenSSL - from mitmproxy.coretypes import serializable # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 -DEFAULT_EXP = 94608000 # = 60 * 60 * 24 * 365 * 3 = 3 years -DEFAULT_EXP_DUMMY_CERT = 31536000 # = 60 * 60 * 24 * 365 = 1 year +CA_EXPIRY = datetime.timedelta(days=3 * 365) +CERT_EXPIRY = datetime.timedelta(days=365) # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = b""" @@ -37,51 +38,166 @@ rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= """ -def create_ca(organization: str, cn: str, exp: int, key_size: int) -> Tuple[OpenSSL.crypto.PKey, OpenSSL.crypto.X509]: - key = OpenSSL.crypto.PKey() - key.generate_key(OpenSSL.crypto.TYPE_RSA, key_size) - cert = OpenSSL.crypto.X509() - cert.set_serial_number(int(time.time() * 10000)) - cert.set_version(2) - cert.get_subject().CN = cn - cert.get_subject().O = organization - cert.gmtime_adj_notBefore(-3600 * 48) - cert.gmtime_adj_notAfter(exp) - cert.set_issuer(cert.get_subject()) - cert.set_pubkey(key) - cert.add_extensions([ - OpenSSL.crypto.X509Extension( - b"basicConstraints", - True, - b"CA:TRUE" - ), - OpenSSL.crypto.X509Extension( - b"nsCertType", - False, - b"sslCA" - ), - OpenSSL.crypto.X509Extension( - b"extendedKeyUsage", - False, - b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" - ), - OpenSSL.crypto.X509Extension( - b"keyUsage", - True, - b"keyCertSign, cRLSign" - ), - OpenSSL.crypto.X509Extension( - b"subjectKeyIdentifier", - False, - b"hash", - subject=cert - ), +class Cert(serializable.Serializable): + _cert: x509.Certificate + + def __init__(self, cert: x509.Certificate): + assert isinstance(cert, x509.Certificate) + self._cert = cert + + def __eq__(self, other): + return self.fingerprint() == other.fingerprint() + + @classmethod + def from_state(cls, state): + return cls.from_pem(state) + + def get_state(self): + return self.to_pem() + + def set_state(self, state): + self._cert = x509.load_pem_x509_certificate(state) + + @classmethod + def from_pem(cls, data: bytes) -> "Cert": + cert = x509.load_pem_x509_certificate(data) # type: ignore + return cls(cert) + + def to_pem(self) -> bytes: + return self._cert.public_bytes(serialization.Encoding.PEM) + + @classmethod + def from_pyopenssl(self, x509: OpenSSL.crypto.X509) -> "Cert": + return Cert(x509.to_cryptography()) + + def to_pyopenssl(self) -> OpenSSL.crypto.X509: + return OpenSSL.crypto.X509.from_cryptography(self._cert) + + def fingerprint(self) -> bytes: + return self._cert.fingerprint(hashes.SHA256()) + + @property + def issuer(self) -> List[Tuple[str, str]]: + return _name_to_keyval(self._cert.issuer) + + @property + def notbefore(self) -> datetime.datetime: + return self._cert.not_valid_before + + @property + def notafter(self) -> datetime.datetime: + return self._cert.not_valid_after + + def has_expired(self) -> bool: + return datetime.datetime.utcnow() > self._cert.not_valid_after + + @property + def subject(self) -> List[Tuple[str, str]]: + return _name_to_keyval(self._cert.subject) + + @property + def serial(self) -> int: + return self._cert.serial_number + + @property + def keyinfo(self): + public_key = self._cert.public_key() + if isinstance(public_key, rsa.RSAPublicKey): + return "RSA", public_key.key_size + if isinstance(public_key, dsa.DSAPublicKey): + return "DSA", public_key.key_size + if isinstance(public_key, ec.EllipticCurvePublicKey): + return f"EC ({public_key.curve.name})", public_key.key_size + return (public_key.__class__.__name__.replace("PublicKey", "").replace("_", ""), + getattr(public_key, "key_size", -1)) # pragma: no cover + + @property + def cn(self) -> Optional[str]: + attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME) + if attrs: + return attrs[0].value + return None + + @property + def organization(self) -> Optional[str]: + attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.ORGANIZATION_NAME) + if attrs: + return attrs[0].value + return None + + @property + def altnames(self) -> List[str]: + """ + Get all SubjectAlternativeName DNS altnames. + """ + try: + ext = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + except x509.ExtensionNotFound: + return [] + else: + return ( + ext.get_values_for_type(x509.DNSName) + + + [str(x) for x in ext.get_values_for_type(x509.IPAddress)] + ) + + +def _name_to_keyval(name: x509.Name) -> List[Tuple[str, str]]: + parts = [] + for rdn in name.rdns: + k, v = rdn.rfc4514_string().split("=", maxsplit=1) + parts.append((k, v)) + return parts + + +def create_ca( + organization: str, + cn: str, + key_size: int, +) -> Tuple[rsa.RSAPrivateKeyWithSerialization, x509.Certificate]: + now = datetime.datetime.now() + + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=key_size, + ) # type: ignore + name = x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, cn), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization) ]) - cert.sign(key, "sha256") - return key, cert + builder = x509.CertificateBuilder() + builder = builder.serial_number(x509.random_serial_number()) + builder = builder.subject_name(name) + builder = builder.not_valid_before(now - datetime.timedelta(days=2)) + builder = builder.not_valid_after(now + CA_EXPIRY) + builder = builder.issuer_name(name) + builder = builder.public_key(private_key.public_key()) + builder = builder.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True) + builder = builder.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False) + builder = builder.add_extension( + x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), critical=True) + builder = builder.add_extension(x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()), critical=False) + cert = builder.sign(private_key=private_key, algorithm=hashes.SHA256()) # type: ignore + return private_key, cert -def dummy_cert(privkey, cacert, commonname, sans, organization): +def dummy_cert( + privkey: rsa.RSAPrivateKey, + cacert: x509.Certificate, + commonname: Optional[str], + sans: List[str], + organization: Optional[str] = None, +) -> Cert: """ Generates a dummy certificate. @@ -93,111 +209,114 @@ def dummy_cert(privkey, cacert, commonname, sans, organization): Returns cert if operation succeeded, None if not. """ - ss = [] - for i in sans: - try: - ipaddress.ip_address(i.decode("ascii")) - except ValueError: - ss.append(b"DNS:%s" % i) - else: - ss.append(b"IP:%s" % i) - ss = b", ".join(ss) + builder = x509.CertificateBuilder() + builder = builder.issuer_name(cacert.subject) + builder = builder.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False) + builder = builder.public_key(cacert.public_key()) - cert = OpenSSL.crypto.X509() - cert.gmtime_adj_notBefore(-3600 * 48) - cert.gmtime_adj_notAfter(DEFAULT_EXP_DUMMY_CERT) - cert.set_issuer(cacert.get_subject()) + now = datetime.datetime.now() + builder = builder.not_valid_before(now - datetime.timedelta(days=2)) + builder = builder.not_valid_after(now + CERT_EXPIRY) + + subject = [] is_valid_commonname = ( - commonname is not None and len(commonname) < 64 + commonname is not None and len(commonname) < 64 ) if is_valid_commonname: - cert.get_subject().CN = commonname + assert commonname is not None + subject.append(x509.NameAttribute(NameOID.COMMON_NAME, commonname)) if organization is not None: - cert.get_subject().O = organization - cert.set_serial_number(int(time.time() * 10000)) - if ss: - cert.set_version(2) - cert.add_extensions( - [OpenSSL.crypto.X509Extension( - b"subjectAltName", - # RFC 5280 §4.2.1.6: subjectAltName is critical if subject is empty. - not is_valid_commonname, - ss - )] - ) - cert.add_extensions([ - OpenSSL.crypto.X509Extension( - b"extendedKeyUsage", - False, - b"serverAuth,clientAuth" - ) - ]) - cert.set_pubkey(cacert.get_pubkey()) - cert.sign(privkey, "sha256") + assert organization is not None + subject.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization)) + builder = builder.subject_name(x509.Name(subject)) + builder = builder.serial_number(x509.random_serial_number()) + + ss: List[x509.GeneralName] = [] + for x in sans: + try: + ip = ipaddress.ip_address(x) + except ValueError: + ss.append(x509.DNSName(x)) + else: + ss.append(x509.IPAddress(ip)) + # RFC 5280 §4.2.1.6: subjectAltName is critical if subject is empty. + builder = builder.add_extension(x509.SubjectAlternativeName(ss), critical=not is_valid_commonname) + cert = builder.sign(private_key=privkey, algorithm=hashes.SHA256()) # type: ignore return Cert(cert) +@dataclass(frozen=True) class CertStoreEntry: - - def __init__(self, cert, privatekey, chain_file): - self.cert = cert - self.privatekey = privatekey - self.chain_file = chain_file + cert: Cert + privatekey: rsa.RSAPrivateKey + chain_file: Optional[Path] -TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs) -TGeneratedCertId = Tuple[Optional[bytes], Tuple[bytes, ...]] # (common_name, sans) +TCustomCertId = str # manually provided certs (e.g. mitmproxy's --certs) +TGeneratedCertId = Tuple[Optional[str], Tuple[str, ...]] # (common_name, sans) TCertId = Union[TCustomCertId, TGeneratedCertId] +DHParams = NewType("DHParams", bytes) + class CertStore: - """ Implements an in-memory certificate store. """ STORE_CAP = 100 + certs: Dict[TCertId, CertStoreEntry] + expire_queue: List[CertStoreEntry] def __init__( self, - default_privatekey, - default_ca, - default_chain_file, - dhparams): + default_privatekey: rsa.RSAPrivateKey, + default_ca: Cert, + default_chain_file: Optional[Path], + dhparams: DHParams + ): self.default_privatekey = default_privatekey self.default_ca = default_ca self.default_chain_file = default_chain_file self.dhparams = dhparams - self.certs: Dict[TCertId, CertStoreEntry] = {} + self.certs = {} self.expire_queue = [] - def expire(self, entry): + def expire(self, entry: CertStoreEntry) -> None: self.expire_queue.append(entry) if len(self.expire_queue) > self.STORE_CAP: d = self.expire_queue.pop(0) self.certs = {k: v for k, v in self.certs.items() if v != d} @staticmethod - def load_dhparam(path: str): - + def load_dhparam(path: Path) -> DHParams: # mitmproxy<=0.10 doesn't generate a dhparam file. # Create it now if necessary. - if not os.path.exists(path): - with open(path, "wb") as f: - f.write(DEFAULT_DHPARAM) + if not path.exists(): + path.write_bytes(DEFAULT_DHPARAM) - bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r") + # we could use cryptography for this, but it's unclear how to convert cryptography's object to pyOpenSSL's + # expected format. + bio = OpenSSL.SSL._lib.BIO_new_file(str(path).encode(sys.getfilesystemencoding()), b"r") if bio != OpenSSL.SSL._ffi.NULL: bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free) dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, - OpenSSL.SSL._ffi.NULL) + OpenSSL.SSL._ffi.NULL + ) dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free) return dh + raise RuntimeError("Error loading DH Params.") # pragma: no cover @classmethod - def from_store(cls, path: Union[Path, str], basename: str, key_size, passphrase: Optional[bytes] = None) -> "CertStore": + def from_store( + cls, + path: Union[Path, str], + basename: str, + key_size: int, + passphrase: Optional[bytes] = None + ) -> "CertStore": path = Path(path) ca_file = path / f"{basename}-ca.pem" dhparam_file = path / f"{basename}-dhparam.pem" @@ -208,17 +327,14 @@ class CertStore: @classmethod def from_files(cls, ca_file: Path, dhparam_file: Path, passphrase: Optional[bytes] = None) -> "CertStore": raw = ca_file.read_bytes() - ca = OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_PEM, - raw - ) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - raw, - passphrase - ) - dh = cls.load_dhparam(str(dhparam_file)) - return cls(key, ca, str(ca_file), dh) + key = load_pem_private_key(raw, passphrase) + ca = Cert.from_pem(raw) + dh = cls.load_dhparam(dhparam_file) + if raw.count(b"BEGIN CERTIFICATE") != 1: + chain_file: Optional[Path] = ca_file + else: + chain_file = None + return cls(key, ca, chain_file, dh) @staticmethod @contextlib.contextmanager @@ -236,74 +352,72 @@ class CertStore: os.umask(original_umask) @staticmethod - def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None, expiry=DEFAULT_EXP) -> None: + def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None) -> None: path.mkdir(parents=True, exist_ok=True) organization = organization or basename cn = cn or basename - key, ca = create_ca(organization=organization, cn=cn, exp=expiry, key_size=key_size) - # Dump the CA plus private key - with CertStore.umask_secret(), (path / f"{basename}-ca.pem").open("wb") as f: - f.write( - OpenSSL.crypto.dump_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - key)) - f.write( - OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - ca)) + key: rsa.RSAPrivateKeyWithSerialization + ca: x509.Certificate + key, ca = create_ca(organization=organization, cn=cn, key_size=key_size) + + # Dump the CA plus private key. + with CertStore.umask_secret(): + # PEM format + (path / f"{basename}-ca.pem").write_bytes( + key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + ca.public_bytes(serialization.Encoding.PEM) + ) + + # PKCS12 format for Windows devices + (path / f"{basename}-ca.p12").write_bytes( + pkcs12.serialize_key_and_certificates( # type: ignore + name=basename.encode(), + key=key, + cert=ca, + cas=None, + encryption_algorithm=serialization.NoEncryption(), + ) + ) # Dump the certificate in PEM format - with (path / f"{basename}-ca-cert.pem").open("wb") as f: - f.write( - OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - ca)) - + pem_cert = ca.public_bytes(serialization.Encoding.PEM) + (path / f"{basename}-ca-cert.pem").write_bytes(pem_cert) # Create a .cer file with the same contents for Android - with (path / f"{basename}-ca-cert.cer").open("wb") as f: - f.write( - OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - ca)) + (path / f"{basename}-ca-cert.cer").write_bytes(pem_cert) # Dump the certificate in PKCS12 format for Windows devices - with (path / f"{basename}-ca-cert.p12").open("wb") as f: - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - f.write(p12.export()) - - # Dump the certificate and key in a PKCS12 format for Windows devices - with CertStore.umask_secret(), (path / f"{basename}-ca.p12").open("wb") as f: - p12 = OpenSSL.crypto.PKCS12() - p12.set_certificate(ca) - p12.set_privatekey(key) - f.write(p12.export()) - - with (path / f"{basename}-dhparam.pem").open("wb") as f: - f.write(DEFAULT_DHPARAM) - - def add_cert_file(self, spec: str, path: str, passphrase: Optional[bytes] = None) -> None: - with open(path, "rb") as f: - raw = f.read() - cert = Cert( - OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_PEM, - raw)) - try: - privatekey = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - raw, - passphrase) - except Exception: - privatekey = self.default_privatekey - self.add_cert( - CertStoreEntry(cert, privatekey, path), - spec.encode("idna") + (path / f"{basename}-ca-cert.p12").write_bytes( + pkcs12.serialize_key_and_certificates( # type: ignore + name=basename.encode(), + key=None, + cert=ca, + cas=None, + encryption_algorithm=serialization.NoEncryption(), + ) ) - def add_cert(self, entry: CertStoreEntry, *names: bytes): + (path / f"{basename}-dhparam.pem").write_bytes(DEFAULT_DHPARAM) + + def add_cert_file(self, spec: str, path: Path, passphrase: Optional[bytes] = None) -> None: + raw = path.read_bytes() + cert = Cert.from_pem(raw) + try: + key = load_pem_private_key(raw, password=passphrase) + except ValueError: + key = self.default_privatekey + + self.add_cert( + CertStoreEntry(cert, key, path), + spec + ) + + def add_cert(self, entry: CertStoreEntry, *names: str) -> None: """ Adds a cert to the certstore. We register the CN in the cert plus any SANs, and also the list of names provided as an argument. @@ -316,26 +430,24 @@ class CertStore: self.certs[i] = entry @staticmethod - def asterisk_forms(dn: bytes) -> List[bytes]: + def asterisk_forms(dn: str) -> List[str]: """ Return all asterisk forms for a domain. For example, for www.example.com this will return [b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted. """ - parts = dn.split(b".") + parts = dn.split(".") ret = [dn] for i in range(1, len(parts)): - ret.append(b"*." + b".".join(parts[i:])) + ret.append("*." + ".".join(parts[i:])) return ret def get_cert( self, - commonname: Optional[bytes], - sans: List[bytes], - organization: Optional[bytes] = None - ) -> Tuple["Cert", OpenSSL.SSL.PKey, str]: + commonname: Optional[str], + sans: List[str], + organization: Optional[str] = None + ) -> CertStoreEntry: """ - Returns an (cert, privkey, cert_chain) tuple. - commonname: Common name for the generated certificate. Must be a valid, plain-ASCII, IDNA-encoded domain name. @@ -349,7 +461,7 @@ class CertStore: potential_keys.extend(self.asterisk_forms(commonname)) for s in sans: potential_keys.extend(self.asterisk_forms(s)) - potential_keys.append(b"*") + potential_keys.append("*") potential_keys.append((commonname, tuple(sans))) name = next( @@ -362,147 +474,27 @@ class CertStore: entry = CertStoreEntry( cert=dummy_cert( self.default_privatekey, - self.default_ca, + self.default_ca._cert, commonname, sans, organization), privatekey=self.default_privatekey, - chain_file=self.default_chain_file) + chain_file=self.default_chain_file + ) self.certs[(commonname, tuple(sans))] = entry self.expire(entry) - return entry.cert, entry.privatekey, entry.chain_file + return entry -class _GeneralName(univ.Choice): - # We only care about dNSName and iPAddress - componentType = namedtype.NamedTypes( - namedtype.NamedType('dNSName', char.IA5String().subtype( - implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) - )), - namedtype.NamedType('iPAddress', univ.OctetString().subtype( - implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 7) - )), - ) - - -class _GeneralNames(univ.SequenceOf): - componentType = _GeneralName() - sizeSpec = univ.SequenceOf.sizeSpec + \ - constraint.ValueSizeConstraint(1, 1024) - - -class Cert(serializable.Serializable): - - def __init__(self, cert): - """ - Returns a (common name, [subject alternative names]) tuple. - """ - self.x509 = cert - - def __eq__(self, other): - return self.digest("sha256") == other.digest("sha256") - - def get_state(self): - return self.to_pem() - - def set_state(self, state): - self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) - - @classmethod - def from_state(cls, state): - return cls.from_pem(state) - - @classmethod - def from_pem(cls, txt): - x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) - return cls(x509) - - @classmethod - def from_der(cls, der): - pem = ssl.DER_cert_to_PEM_cert(der) - return cls.from_pem(pem) - - def to_pem(self): - return OpenSSL.crypto.dump_certificate( - OpenSSL.crypto.FILETYPE_PEM, - self.x509) - - def digest(self, name): - return self.x509.digest(name) - - @property - def issuer(self): - return self.x509.get_issuer().get_components() - - @property - def notbefore(self): - t = self.x509.get_notBefore() - return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") - - @property - def notafter(self): - t = self.x509.get_notAfter() - return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ") - - @property - def has_expired(self): - return self.x509.has_expired() - - @property - def subject(self): - return self.x509.get_subject().get_components() - - @property - def serial(self): - return self.x509.get_serial_number() - - @property - def keyinfo(self): - pk = self.x509.get_pubkey() - types = { - OpenSSL.crypto.TYPE_RSA: "RSA", - OpenSSL.crypto.TYPE_DSA: "DSA", - } - return ( - types.get(pk.type(), "UNKNOWN"), - pk.bits() - ) - - @property - def cn(self) -> Optional[bytes]: - c = None - for i in self.subject: - if i[0] == b"CN": - c = i[1] - return c - - @property - def organization(self) -> Optional[bytes]: - c = None - for i in self.subject: - if i[0] == b"O": - c = i[1] - return c - - @property - def altnames(self) -> List[bytes]: - """ - Returns: - All DNS altnames. - """ - # tcp.TCPClient.convert_to_tls assumes that this property only contains DNS altnames for hostname verification. - altnames = [] - for i in range(self.x509.get_extension_count()): - ext = self.x509.get_extension(i) - if ext.get_short_name() == b"subjectAltName": - try: - dec = decode(ext.get_data(), asn1Spec=_GeneralNames()) - except PyAsn1Error: - continue - for x in dec[0]: - if x[0].hasValue(): - e = x[0].asOctets() - altnames.append(e) - - return altnames +def load_pem_private_key(data: bytes, password: Optional[bytes]) -> rsa.RSAPrivateKey: + """ + like cryptography's load_pem_private_key, but silently falls back to not using a password + if the private key is unencrypted. + """ + try: + return serialization.load_pem_private_key(data, password) # type: ignore + except TypeError: + if password is not None: + return load_pem_private_key(data, None) + raise diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py index 810003d57..2f098758f 100644 --- a/mitmproxy/io/compat.py +++ b/mitmproxy/io/compat.py @@ -230,6 +230,23 @@ def convert_9_10(data): return data +def convert_10_11(data): + data["version"] = 11 + + def conv_conn(conn): + conn["sni"] = strutils.always_str(conn["sni"], "ascii", "backslashreplace") + conn["alpn"] = conn.pop("alpn_proto_negotiated") + conn["alpn_offers"] = conn["alpn_offers"] or [] + conn["cipher_list"] = conn["cipher_list"] or [] + + conv_conn(data["client_conn"]) + conv_conn(data["server_conn"]) + if data["server_conn"]["via"]: + conv_conn(data["server_conn"]["via"]) + + return data + + def _convert_dict_keys(o: Any) -> Any: if isinstance(o, dict): return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} @@ -287,6 +304,7 @@ converters = { 7: convert_7_8, 8: convert_8_9, 9: convert_9_10, + 10: convert_10_11, } diff --git a/mitmproxy/net/tls.py b/mitmproxy/net/tls.py index abb5a5cdd..5af4aaa26 100644 --- a/mitmproxy/net/tls.py +++ b/mitmproxy/net/tls.py @@ -6,9 +6,10 @@ from pathlib import Path from typing import Iterable, Callable, Optional, Tuple, List, Any, BinaryIO import certifi +from cryptography.hazmat.primitives.asymmetric import rsa from kaitaistruct import KaitaiStream -from OpenSSL import SSL +from OpenSSL import SSL, crypto from mitmproxy import certs from mitmproxy.contrib.kaitaistruct import tls_client_hello from mitmproxy.net import check @@ -129,7 +130,7 @@ def create_proxy_server_context( max_version: Version, cipher_list: Optional[Iterable[str]], verify: Verify, - sni: Optional[bytes], + sni: Optional[str], ca_path: Optional[str], ca_pemfile: Optional[str], client_cert: Optional[str], @@ -147,6 +148,7 @@ def create_proxy_server_context( context.set_verify(verify.value, None) if sni is not None: + assert isinstance(sni, str) # Manually enable hostname verification on the context object. # https://wiki.openssl.org/index.php/Hostname_validation param = SSL._lib.SSL_CTX_get0_param(context._context) @@ -157,7 +159,7 @@ def create_proxy_server_context( SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT ) SSL._openssl_assert( - SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni, 0) == 1 + SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni.encode(), 0) == 1 ) if ca_path is None and ca_pemfile is None: @@ -188,12 +190,12 @@ def create_client_proxy_context( max_version: Version, cipher_list: Optional[Iterable[str]], cert: certs.Cert, - key: SSL.PKey, - chain_file: str, + key: rsa.RSAPrivateKey, + chain_file: Optional[Path], alpn_select_callback: Optional[Callable[[SSL.Connection, List[bytes]], Any]], request_client_cert: bool, extra_chain_certs: Iterable[certs.Cert], - dhparams, + dhparams: certs.DHParams, ) -> SSL.Context: context: SSL.Context = _create_ssl_context( method=Method.TLS_SERVER_METHOD, @@ -202,12 +204,13 @@ def create_client_proxy_context( cipher_list=cipher_list, ) - context.use_certificate(cert.x509) - context.use_privatekey(key) - try: - context.load_verify_locations(chain_file, None) - except SSL.Error as e: - raise RuntimeError(f"Cannot load certificate chain ({chain_file}).") from e + context.use_certificate(cert.to_pyopenssl()) + context.use_privatekey(crypto.PKey.from_cryptography_key(key)) + if chain_file is not None: + try: + context.load_verify_locations(str(chain_file), None) + except SSL.Error as e: + raise RuntimeError(f"Cannot load certificate chain ({chain_file}).") from e if alpn_select_callback is not None: assert callable(alpn_select_callback) @@ -227,7 +230,7 @@ def create_client_proxy_context( context.set_verify(Verify.VERIFY_NONE.value, None) for i in extra_chain_certs: - context.add_extra_chain_cert(i.x509) + context.add_extra_chain_cert(i._cert) if dhparams: SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) @@ -277,7 +280,7 @@ class ClientHello: return self._client_hello.cipher_suites.cipher_suites @property - def sni(self) -> Optional[bytes]: + def sni(self) -> Optional[str]: if self._client_hello.extensions: for extension in self._client_hello.extensions.extensions: is_valid_sni_extension = ( @@ -287,11 +290,11 @@ class ClientHello: check.is_valid_host(extension.body.server_names[0].host_name) ) if is_valid_sni_extension: - return extension.body.server_names[0].host_name + return extension.body.server_names[0].host_name.decode("ascii") return None @property - def alpn_protocols(self): + def alpn_protocols(self) -> List[bytes]: if self._client_hello.extensions: for extension in self._client_hello.extensions.extensions: if extension.type == 0x10: diff --git a/mitmproxy/proxy/context.py b/mitmproxy/proxy/context.py index e9d47a33f..785830b4b 100644 --- a/mitmproxy/proxy/context.py +++ b/mitmproxy/proxy/context.py @@ -65,7 +65,7 @@ class Connection(serializable.Serializable, metaclass=ABCMeta): cipher_list: Sequence[str] = () """Ciphers accepted by the proxy server on this connection.""" tls_version: Optional[str] = None - sni: Union[bytes, Literal[True], None] + sni: Union[str, Literal[True], None] timestamp_end: Optional[float] = None """Connection end timestamp""" @@ -109,7 +109,8 @@ class Client(Connection): timestamp_start: float """TCP SYN received""" - sni: Union[bytes, None] = None + mitmcert: Optional[certs.Cert] = None + sni: Union[str, None] = None def __init__(self, peername, sockname, timestamp_start): self.id = str(uuid.uuid4()) @@ -122,10 +123,10 @@ class Client(Connection): # This means we need to add all new fields to the old implementation. return { 'address': self.peername, - 'alpn_proto_negotiated': self.alpn, + 'alpn': self.alpn, 'cipher_name': self.cipher, 'id': self.id, - 'mitmcert': None, + 'mitmcert': self.mitmcert.get_state() if self.mitmcert is not None else None, 'sni': self.sni, 'timestamp_end': self.timestamp_end, 'timestamp_start': self.timestamp_start, @@ -155,7 +156,7 @@ class Client(Connection): def set_state(self, state): self.peername = tuple(state["address"]) if state["address"] else None - self.alpn = state["alpn_proto_negotiated"] + self.alpn = state["alpn"] self.cipher = state["cipher_name"] self.id = state["id"] self.sni = state["sni"] @@ -169,6 +170,7 @@ class Client(Connection): self.error = state["error"] self.tls = state["tls"] self.certificate_list = [certs.Cert.from_state(x) for x in state["certificate_list"]] + self.mitmcert = certs.Cert.from_state(state["mitmcert"]) if state["mitmcert"] is not None else None self.alpn_offers = state["alpn_offers"] self.cipher_list = state["cipher_list"] @@ -215,7 +217,7 @@ class Server(Connection): timestamp_tcp_setup: Optional[float] = None """TCP ACK received""" - sni: Union[bytes, Literal[True], None] = True + sni: Union[str, Literal[True], None] = True """True: client SNI, False: no SNI, bytes: custom value""" via: Optional[server_spec.ServerSpec] = None @@ -226,7 +228,7 @@ class Server(Connection): def get_state(self): return { 'address': self.address, - 'alpn_proto_negotiated': self.alpn, + 'alpn': self.alpn, 'id': self.id, 'ip_address': self.peername, 'sni': self.sni, @@ -257,7 +259,7 @@ class Server(Connection): def set_state(self, state): self.address = tuple(state["address"]) if state["address"] else None - self.alpn = state["alpn_proto_negotiated"] + self.alpn = state["alpn"] self.id = state["id"] self.peername = tuple(state["ip_address"]) if state["ip_address"] else None self.sni = state["sni"] diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index bc2d87ac5..dc9d501cc 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -435,7 +435,7 @@ class HttpStream(layer.Layer): stack = tunnel.LayerStack() if self.context.server.via.scheme == "https": - http_proxy.sni = self.context.server.via.address[0].encode() + http_proxy.sni = self.context.server.via.address[0] stack /= tls.ServerTLSLayer(self.context, http_proxy) stack /= _upstream_proxy.HttpUpstreamProxy(self.context, http_proxy, True) @@ -635,7 +635,7 @@ class HttpLayer(layer.Layer): context.server = Server(event.address) if event.tls: - context.server.sni = event.address[0].encode() + context.server.sni = event.address[0] if event.via: assert event.via.scheme in ("http", "https") @@ -643,7 +643,7 @@ class HttpLayer(layer.Layer): if event.via.scheme == "https": http_proxy.alpn_offers = tls.HTTP_ALPNS - http_proxy.sni = event.via.address[0].encode() + http_proxy.sni = event.via.address[0] stack /= tls.ServerTLSLayer(context, http_proxy) send_connect = not (self.mode == HTTPMode.upstream and not event.tls) diff --git a/mitmproxy/proxy/layers/modes.py b/mitmproxy/proxy/layers/modes.py index 4ff11bb87..3ea7bb9bb 100644 --- a/mitmproxy/proxy/layers/modes.py +++ b/mitmproxy/proxy/layers/modes.py @@ -42,7 +42,7 @@ class ReverseProxy(DestinationKnown): if spec.scheme not in ("http", "tcp"): if not self.context.options.keep_host_header: - self.context.server.sni = spec.address[0].encode() + self.context.server.sni = spec.address[0] self.child_layer = tls.ServerTLSLayer(self.context) else: self.child_layer = layer.NextLayer(self.context) diff --git a/mitmproxy/proxy/layers/tls.py b/mitmproxy/proxy/layers/tls.py index 6f2aaa47b..1b2e8cbc6 100644 --- a/mitmproxy/proxy/layers/tls.py +++ b/mitmproxy/proxy/layers/tls.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Iterator, Optional, Tuple from OpenSSL import SSL - from mitmproxy import certs from mitmproxy.net import tls as net_tls from mitmproxy.proxy import commands, events, layer, tunnel @@ -184,6 +183,8 @@ class _TLSLayer(tunnel.TunnelLayer): err = f"OpenSSL {e!r}" return False, err else: + # Here we set all attributes that are only known *after* the handshake. + # Get all peer certificates. # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html # If called on the client side, the stack also contains the peer's certificate; if called on the server @@ -195,9 +196,8 @@ class _TLSLayer(tunnel.TunnelLayer): all_certs.insert(0, cert) self.conn.timestamp_tls_setup = time.time() - self.conn.sni = self.tls.get_servername() self.conn.alpn = self.tls.get_alpn_proto_negotiated() - self.conn.certificate_list = [certs.Cert(x) for x in all_certs] + self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs] self.conn.cipher = self.tls.get_cipher_name() self.conn.tls_version = self.tls.get_protocol_version_name() if self.debug: @@ -338,8 +338,7 @@ class ClientTLSLayer(_TLSLayer): def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]: if self.conn.sni: - assert isinstance(self.conn.sni, bytes) - dest = self.conn.sni.decode("idna") + dest = self.conn.sni else: dest = human.format_address(self.context.server.address) if err.startswith("Cannot parse ClientHello"): diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 13a6e777f..b2b4607eb 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -427,7 +427,8 @@ if __name__ == "__main__": # pragma: no cover tls_start.ssl_conn.set_accept_state() else: tls_start.ssl_conn.set_connect_state() - tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni) + if tls_start.context.client.sni is not None: + tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni.encode()) await SimpleConnectionHandler(reader, writer, opts, { "next_layer": next_layer, diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index 1a2f6edaf..0a0755ad1 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -158,7 +158,7 @@ def tclient_conn() -> context.Client: timestamp_end=946681206, sni="address", cipher_name="cipher", - alpn_proto_negotiated=b"http/1.1", + alpn=b"http/1.1", tls_version="TLSv1.2", tls_extensions=[(0x00, bytes.fromhex("000e00000b6578616d"))], state=0, @@ -185,7 +185,7 @@ def tserver_conn() -> context.Server: timestamp_end=946681205, tls_established=True, sni="address", - alpn_proto_negotiated=None, + alpn=None, tls_version="TLSv1.2", via=None, state=0, diff --git a/mitmproxy/tools/console/flowdetailview.py b/mitmproxy/tools/console/flowdetailview.py index b61f2476a..aa3d449c8 100644 --- a/mitmproxy/tools/console/flowdetailview.py +++ b/mitmproxy/tools/console/flowdetailview.py @@ -4,8 +4,7 @@ import urwid import mitmproxy.flow from mitmproxy import http from mitmproxy.tools.console import common, searchable -from mitmproxy.utils import human -from mitmproxy.utils import strutils +from mitmproxy.utils import human, strutils def maybe_timestamp(base, attr): @@ -40,64 +39,37 @@ def flowdetails(state, flow: mitmproxy.flow.Flow): text.append(urwid.Text([("head", "Metadata:")])) text.extend(common.format_keyvals(parts, indent=4)) - if sc is not None and sc.ip_address: + if sc is not None and sc.peername: text.append(urwid.Text([("head", "Server Connection:")])) parts = [ ("Address", human.format_address(sc.address)), ] - if sc.ip_address: - parts.append(("Resolved Address", human.format_address(sc.ip_address))) + if sc.peername: + parts.append(("Resolved Address", human.format_address(sc.peername))) if resp: parts.append(("HTTP Version", resp.http_version)) - if sc.alpn_proto_negotiated: - parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn_proto_negotiated))) + if sc.alpn: + parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn))) text.extend( common.format_keyvals(parts, indent=4) ) - c = sc.cert + c = sc.certificate_list[0] if c: text.append(urwid.Text([("head", "Server Certificate:")])) parts = [ ("Type", "%s, %s bits" % c.keyinfo), - ("SHA1 digest", c.digest("sha1")), + ("SHA256 digest", c.fingerprint().hex()), ("Valid to", str(c.notafter)), ("Valid from", str(c.notbefore)), ("Serial", str(c.serial)), - ( - "Subject", - urwid.BoxAdapter( - urwid.ListBox( - common.format_keyvals( - c.subject, - key_format="highlight" - ) - ), - len(c.subject) - ) - ), - ( - "Issuer", - urwid.BoxAdapter( - urwid.ListBox( - common.format_keyvals( - c.issuer, - key_format="highlight" - ) - ), - len(c.issuer) - ) - ) + ("Subject", urwid.Pile(common.format_keyvals(c.subject, key_format="highlight"))), + ("Issuer", urwid.Pile(common.format_keyvals(c.issuer, key_format="highlight"))) ] if c.altnames: - parts.append( - ( - "Alt names", - ", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames) - ) - ) + parts.append(("Alt names", ", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames))) text.extend( common.format_keyvals(parts, indent=4) ) @@ -106,19 +78,18 @@ def flowdetails(state, flow: mitmproxy.flow.Flow): text.append(urwid.Text([("head", "Client Connection:")])) parts = [ - ("Address", "{}:{}".format(cc.address[0], cc.address[1])), + ("Address", human.format_address(cc.peername)), ] if req: parts.append(("HTTP Version", req.http_version)) if cc.tls_version: parts.append(("TLS Version", cc.tls_version)) if cc.sni: - parts.append(("Server Name Indication", - strutils.bytes_to_escaped_str(strutils.always_bytes(cc.sni, "idna")))) - if cc.cipher_name: - parts.append(("Cipher Name", cc.cipher_name)) - if cc.alpn_proto_negotiated: - parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn_proto_negotiated))) + parts.append(("Server Name Indication", cc.sni)) + if cc.cipher: + parts.append(("Cipher Name", cc.cipher)) + if cc.alpn: + parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn))) text.extend( common.format_keyvals(parts, indent=4) diff --git a/mitmproxy/tools/web/app.py b/mitmproxy/tools/web/app.py index 54942fc38..c24e456c8 100644 --- a/mitmproxy/tools/web/app.py +++ b/mitmproxy/tools/web/app.py @@ -47,8 +47,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict: "timestamp_start": flow.client_conn.timestamp_start, "timestamp_tls_setup": flow.client_conn.timestamp_tls_setup, "timestamp_end": flow.client_conn.timestamp_end, - # ideally idna, but we don't want errors - "sni": always_str(flow.client_conn.sni, "ascii", "backslashreplace"), + "sni": flow.client_conn.sni, "cipher_name": flow.client_conn.cipher, "alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"), "tls_version": flow.client_conn.tls_version, @@ -61,18 +60,14 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict: "ip_address": flow.server_conn.peername, "source_address": flow.server_conn.sockname, "tls_established": flow.server_conn.tls_established, - "alpn_proto_negotiated": always_str(flow.server_conn.alpn, "ascii", "backslashreplace"), + "sni": flow.server_conn.sni, + "alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"), "tls_version": flow.server_conn.tls_version, "timestamp_start": flow.server_conn.timestamp_start, "timestamp_tcp_setup": flow.server_conn.timestamp_tcp_setup, "timestamp_tls_setup": flow.server_conn.timestamp_tls_setup, "timestamp_end": flow.server_conn.timestamp_end, } - if flow.server_conn.sni is True: - f["server_conn"] = None - else: - # ideally idna, but we don't want errors - f["server_conn"] = always_str(flow.server_conn.sni, "ascii", "backslashreplace") if flow.error: f["error"] = flow.error.get_state() diff --git a/mitmproxy/version.py b/mitmproxy/version.py index a559d16ba..883469ddd 100644 --- a/mitmproxy/version.py +++ b/mitmproxy/version.py @@ -7,7 +7,7 @@ MITMPROXY = "mitmproxy " + VERSION # Serialization format version. This is displayed nowhere, it just needs to be incremented by one # for each change in the file format. -FLOW_FORMAT_VERSION = 10 +FLOW_FORMAT_VERSION = 11 def get_dev_version() -> str: diff --git a/setup.py b/setup.py index 4b6435e02..84b83d17e 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,6 @@ setup( "msgpack>=1.0.0, <1.1.0", "passlib>=1.6.5, <1.8", "protobuf>=3.14,<3.15", - "pyasn1>=0.3.1,<0.5", "pyOpenSSL>=20.0,<20.1", "pyparsing>=2.4.2,<2.5", "pyperclip>=1.6.0,<1.9", diff --git a/test/helper_tools/memoryleak2.py b/test/helper_tools/memoryleak2.py new file mode 100644 index 000000000..26fa742d9 --- /dev/null +++ b/test/helper_tools/memoryleak2.py @@ -0,0 +1,21 @@ +import secrets +from pathlib import Path + +import objgraph + +from mitmproxy import certs + +if __name__ == "__main__": + store = certs.CertStore.from_store(path=Path("~/.mitmproxy/").expanduser(), basename="mitmproxy", key_size=2048) + store.STORE_CAP = 5 + + for _ in range(5): + store.get_cert(commonname=secrets.token_hex(16).encode(), sans=[], organization=None) + + objgraph.show_growth() + + for _ in range(20): + store.get_cert(commonname=secrets.token_hex(16).encode(), sans=[], organization=None) + + print("====") + objgraph.show_growth() diff --git a/test/mitmproxy/addons/test_tlsconfig.py b/test/mitmproxy/addons/test_tlsconfig.py index 3931aba24..26d0a491b 100644 --- a/test/mitmproxy/addons/test_tlsconfig.py +++ b/test/mitmproxy/addons/test_tlsconfig.py @@ -57,21 +57,21 @@ class TestTlsConfig: ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options) # Edge case first: We don't have _any_ idea about the server, so we just return "mitmproxy" as subject. - cert, pkey, chainfile = ta.get_cert(ctx) - assert cert.cn == b"mitmproxy" + entry = ta.get_cert(ctx) + assert entry.cert.cn == "mitmproxy" # Here we have an existing server connection... ctx.server.address = ("server-address.example", 443) with open(tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.crt"), "rb") as f: ctx.server.certificate_list = [certs.Cert.from_pem(f.read())] - cert, pkey, chainfile = ta.get_cert(ctx) - assert cert.cn == b"example.mitmproxy.org" - assert cert.altnames == [b"example.mitmproxy.org", b"server-address.example"] + entry = ta.get_cert(ctx) + assert entry.cert.cn == "example.mitmproxy.org" + assert entry.cert.altnames == ["example.mitmproxy.org", "server-address.example"] # And now we also incorporate SNI. - ctx.client.sni = b"sni.example" - cert, pkey, chainfile = ta.get_cert(ctx) - assert cert.altnames == [b"example.mitmproxy.org", b"sni.example"] + ctx.client.sni = "sni.example" + entry = ta.get_cert(ctx) + assert entry.cert.altnames == ["example.mitmproxy.org", "sni.example"] def test_tls_clienthello(self): # only really testing for coverage here, there's no point in mirroring the individual conditions @@ -222,7 +222,7 @@ class TestTlsConfig: @pytest.mark.asyncio async def test_ca_expired(self, monkeypatch): - monkeypatch.setattr(SSL.X509, "has_expired", lambda self: True) + monkeypatch.setattr(certs.Cert, "has_expired", lambda self: True) ta = tlsconfig.TlsConfig() with taddons.context(ta) as tctx: ta.configure(["confdir"]) diff --git a/test/mitmproxy/data/confdir/mitmproxy-dhparam.pem b/test/mitmproxy/data/confdir/mitmproxy-dhparam.pem new file mode 100644 index 000000000..c10121fbf --- /dev/null +++ b/test/mitmproxy/data/confdir/mitmproxy-dhparam.pem @@ -0,0 +1,14 @@ + +-----BEGIN DH PARAMETERS----- +MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3 +O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv +j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ +Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB +chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC +ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq +o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX +IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv +A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8 +6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I +rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= +-----END DH PARAMETERS----- diff --git a/test/mitmproxy/net/data/dsa_cert.pem b/test/mitmproxy/net/data/dsa_cert.pem new file mode 100644 index 000000000..b0230435a --- /dev/null +++ b/test/mitmproxy/net/data/dsa_cert.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDHTCCAtugAwIBAgIJAJNwd38WrLI/MAsGCWCGSAFlAwQDAjBFMQswCQYDVQQG +EwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lk +Z2l0cyBQdHkgTHRkMB4XDTIwMTIzMDE5MzkxMloXDTIxMDEyOTE5MzkxMlowRTEL +MAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVy +bmV0IFdpZGdpdHMgUHR5IEx0ZDCCAbcwggEsBgcqhkjOOAQBMIIBHwKBgQDL8YvJ +wqiJEdWKh7KBbgyCxpWhZdxyLFMguXwI2iX3+n8m2vg/6kBLK/sbI3y3qTeQOvQ0 +V/WchWWAnGHT9CwB1MNhqMWd409ZAo49/m6IPYvoB1RctKyUE9E8jUad+jM2WGm9 +zCG+KPeYFebTsWiQhhIgx8vPM657x6wA5y+omQIVAP4uMtD8Xv+TZCeXTB4j2bi0 +yxjTAoGBAIRxOxIma4B4sEToblQjbwQh9gnWEhvvwAefu3Gcav12mRgYBVqPeskJ +0zROLFr9ubGZo0F/xBGCpaVjY+nDbE8/MwURLBfBr2wQTDbXL/Z5Ea1wNIURsqdo +VxXKADqPh/OfylObCONxF37RnSiABwlCPpIERj+g90oNSFLnvftiA4GEAAKBgC6X +5sBsCV71iCPZSM1UCiSOnzCg/6w5nHPawwcdPUtU6/mpOdylfHlttr4jA1c2a3TB +9ppxe1H0945nYFCHmU5zGtgwKPlNMS5GLrHtSbxiwB5dUm5AsNsLA0EVLF1Po7sI +xQRWTWp57iLzRGdDoKFiPaW07g/UnNzd6i/RAfcWo1MwUTAdBgNVHQ4EFgQUtX4q +D2f00P8e4zp9L+zOib+Fv10wHwYDVR0jBBgwFoAUtX4qD2f00P8e4zp9L+zOib+F +v10wDwYDVR0TAQH/BAUwAwEB/zALBglghkgBZQMEAwIDLwAwLAIURQwecWi0wMjS +EsmC5hxASeuH2UkCFG6zZfM0Sbms27vcrCfVEtxzhizU +-----END CERTIFICATE----- diff --git a/test/mitmproxy/net/data/ec_cert.pem b/test/mitmproxy/net/data/ec_cert.pem new file mode 100644 index 000000000..8103b4fd1 --- /dev/null +++ b/test/mitmproxy/net/data/ec_cert.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIB1DCCAXqgAwIBAgIJALAxolM7r60uMAoGCCqGSM49BAMCMEUxCzAJBgNVBAYT +AkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn +aXRzIFB0eSBMdGQwHhcNMjAxMjMwMTkzOTIyWhcNMjEwMTI5MTkzOTIyWjBFMQsw +CQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJu +ZXQgV2lkZ2l0cyBQdHkgTHRkMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbfF4 +pQnC+cJvoOyp8g9scQsJLNBsd4Vd0ADCyLBx84B1WucmBfpqi+8ERo8e8Y899UKv +KWov2eqOUmGiCLS3kaNTMFEwHQYDVR0OBBYEFHRF3WQJ6x05/c29Rv0mPZn6CwKb +MB8GA1UdIwQYMBaAFHRF3WQJ6x05/c29Rv0mPZn6CwKbMA8GA1UdEwEB/wQFMAMB +Af8wCgYIKoZIzj0EAwIDSAAwRQIgVDA9XrBWlQQteO0LHlkXVpZH0QXjtnQi426Y +ZxGgG18CIQCVSh3iWEAsCXWWUwiVmqEbpWGNIwwpbK4wL7+5lvUvXw== +-----END CERTIFICATE----- diff --git a/test/mitmproxy/net/test_tls.py b/test/mitmproxy/net/test_tls.py index c12b45128..8833743db 100644 --- a/test/mitmproxy/net/test_tls.py +++ b/test/mitmproxy/net/test_tls.py @@ -29,7 +29,7 @@ def test_sslkeylogfile(tdata, monkeypatch): Path(tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.pem")), Path(tdata.path("mitmproxy/net/data/dhparam.pem")) ) - cert, key, chain_file = store.get_cert(b"example.com", [], None) + entry = store.get_cert("example.com", [], None) cctx = tls.create_proxy_server_context( min_version=tls.DEFAULT_MIN_VERSION, @@ -46,9 +46,9 @@ def test_sslkeylogfile(tdata, monkeypatch): min_version=tls.DEFAULT_MIN_VERSION, max_version=tls.DEFAULT_MAX_VERSION, cipher_list=None, - cert=cert, - key=key, - chain_file=chain_file, + cert=entry.cert, + key=entry.privatekey, + chain_file=entry.chain_file, alpn_select_callback=None, request_client_cert=False, extra_chain_certs=(), @@ -105,7 +105,7 @@ class TestClientHello: ) c = tls.ClientHello(data) assert repr(c) - assert c.sni == b'example.com' + assert c.sni == 'example.com' assert c.cipher_suites == [ 49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161, 49171, 49162, 49172, 156, 157, 47, 53, 10 diff --git a/test/mitmproxy/proxy/layers/test_modes.py b/test/mitmproxy/proxy/layers/test_modes.py index 302fe234d..48894a757 100644 --- a/test/mitmproxy/proxy/layers/test_modes.py +++ b/test/mitmproxy/proxy/layers/test_modes.py @@ -202,7 +202,7 @@ def test_reverse_proxy_tcp_over_tls(tctx: Context, monkeypatch, patch, connectio >> reply_tls_start() << SendData(tctx.server, data) ) - assert tls.parse_client_hello(data()).sni == b"localhost" + assert tls.parse_client_hello(data()).sni == "localhost" @pytest.mark.parametrize("connection_strategy", ["eager", "lazy"]) diff --git a/test/mitmproxy/proxy/layers/test_tls.py b/test/mitmproxy/proxy/layers/test_tls.py index b0219838d..925896a09 100644 --- a/test/mitmproxy/proxy/layers/test_tls.py +++ b/test/mitmproxy/proxy/layers/test_tls.py @@ -76,7 +76,7 @@ def test_get_client_hello(): def test_parse_client_hello(): - assert tls.parse_client_hello(client_hello_with_extensions).sni == b"example.com" + assert tls.parse_client_hello(client_hello_with_extensions).sni == "example.com" assert tls.parse_client_hello(client_hello_with_extensions[:50]) is None with pytest.raises(ValueError): tls.parse_client_hello(client_hello_with_extensions[:183] + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00') @@ -188,7 +188,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut tls_start.ssl_conn = SSL.Connection(ssl_context) tls_start.ssl_conn.set_connect_state() # Set SNI - tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni) + tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode()) # Manually enable hostname verification. # Recent OpenSSL versions provide slightly nicer ways to do this, but they are not exposed in @@ -202,7 +202,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT ) SSL._openssl_assert( - SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni, 0) == 1 + SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode(), 0) == 1 ) return tutils.reply(*args, side_effect=make_conn, **kwargs) @@ -227,7 +227,7 @@ class TestServerTLS: playbook = tutils.Playbook(tls.ServerTLSLayer(tctx)) tctx.server.state = ConnectionState.OPEN tctx.server.address = ("example.mitmproxy.org", 443) - tctx.server.sni = b"example.mitmproxy.org" + tctx.server.sni = "example.mitmproxy.org" tssl = SSLTest(server_side=True) @@ -280,7 +280,7 @@ class TestServerTLS: """If the certificate is not trusted, we should fail.""" playbook = tutils.Playbook(tls.ServerTLSLayer(tctx)) tctx.server.address = ("wrong.host.mitmproxy.org", 443) - tctx.server.sni = b"wrong.host.mitmproxy.org" + tctx.server.sni = "wrong.host.mitmproxy.org" tssl = SSLTest(server_side=True) @@ -316,7 +316,7 @@ class TestServerTLS: def test_remote_speaks_no_tls(self, tctx): playbook = tutils.Playbook(tls.ServerTLSLayer(tctx)) tctx.server.state = ConnectionState.OPEN - tctx.server.sni = b"example.mitmproxy.org" + tctx.server.sni = "example.mitmproxy.org" # send ClientHello, receive random garbage back data = tutils.Placeholder(bytes) @@ -345,7 +345,7 @@ def make_client_tls_layer( # Add some server config, this is needed anyways. tctx.server.address = ("example.mitmproxy.org", 443) - tctx.server.sni = b"example.mitmproxy.org" + tctx.server.sni = "example.mitmproxy.org" tssl_client = SSLTest(**kwargs) # Start handshake. diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index 72d891938..42d3e7098 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -1,7 +1,12 @@ import os +from pathlib import Path + +import pytest + from mitmproxy import certs from ..conftest import skip_windows + # class TestDNTree: # def test_simple(self): # d = certs.DNTree() @@ -32,85 +37,87 @@ from ..conftest import skip_windows # assert d.get("com") == "foo" +@pytest.fixture() +def tstore(tdata): + return certs.CertStore.from_store(tdata.path("mitmproxy/data/confdir"), "mitmproxy", 2048) + + class TestCertStore: def test_create_explicit(self, tmpdir): ca = certs.CertStore.from_store(str(tmpdir), "test", 2048) - assert ca.get_cert(b"foo", []) + assert ca.get_cert("foo", []) ca2 = certs.CertStore.from_store(str(tmpdir), "test", 2048) - assert ca2.get_cert(b"foo", []) + assert ca2.get_cert("foo", []) - assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() + assert ca.default_ca.serial == ca2.default_ca.serial - def test_create_no_common_name(self, tmpdir): - ca = certs.CertStore.from_store(str(tmpdir), "test", 2048) - assert ca.get_cert(None, [])[0].cn is None + def test_create_no_common_name(self, tstore): + assert tstore.get_cert(None, []).cert.cn is None - def test_create_tmp(self, tmpdir): - ca = certs.CertStore.from_store(str(tmpdir), "test", 2048) - assert ca.get_cert(b"foo.com", []) - assert ca.get_cert(b"foo.com", []) - assert ca.get_cert(b"*.foo.com", []) + def test_chain_file(self, tdata, tmp_path): + cert = Path(tdata.path("mitmproxy/data/confdir/mitmproxy-ca.pem")).read_bytes() + (tmp_path / "mitmproxy-ca.pem").write_bytes(cert) + ca = certs.CertStore.from_store(tmp_path, "mitmproxy", 2048) + assert ca.default_chain_file is None - r = ca.get_cert(b"*.foo.com", []) - assert r[1] == ca.default_privatekey + (tmp_path / "mitmproxy-ca.pem").write_bytes(2 * cert) + ca = certs.CertStore.from_store(tmp_path, "mitmproxy", 2048) + assert ca.default_chain_file == (tmp_path / "mitmproxy-ca.pem") - def test_sans(self, tmpdir): - ca = certs.CertStore.from_store(str(tmpdir), "test", 2048) - c1 = ca.get_cert(b"foo.com", [b"*.bar.com"]) - ca.get_cert(b"foo.bar.com", []) + def test_sans(self, tstore): + c1 = tstore.get_cert("foo.com", ["*.bar.com"]) + tstore.get_cert("foo.bar.com", []) # assert c1 == c2 - c3 = ca.get_cert(b"bar.com", []) + c3 = tstore.get_cert("bar.com", []) assert not c1 == c3 - def test_sans_change(self, tmpdir): - ca = certs.CertStore.from_store(str(tmpdir), "test", 2048) - ca.get_cert(b"foo.com", [b"*.bar.com"]) - cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"]) - assert b"*.baz.com" in cert.altnames + def test_sans_change(self, tstore): + tstore.get_cert("foo.com", ["*.bar.com"]) + entry = tstore.get_cert("foo.bar.com", ["*.baz.com"]) + assert "*.baz.com" in entry.cert.altnames - def test_expire(self, tmpdir): - ca = certs.CertStore.from_store(str(tmpdir), "test", 2048) - ca.STORE_CAP = 3 - ca.get_cert(b"one.com", []) - ca.get_cert(b"two.com", []) - ca.get_cert(b"three.com", []) + def test_expire(self, tstore): + tstore.STORE_CAP = 3 + tstore.get_cert("one.com", []) + tstore.get_cert("two.com", []) + tstore.get_cert("three.com", []) - assert (b"one.com", ()) in ca.certs - assert (b"two.com", ()) in ca.certs - assert (b"three.com", ()) in ca.certs + assert ("one.com", ()) in tstore.certs + assert ("two.com", ()) in tstore.certs + assert ("three.com", ()) in tstore.certs - ca.get_cert(b"one.com", []) + tstore.get_cert("one.com", []) - assert (b"one.com", ()) in ca.certs - assert (b"two.com", ()) in ca.certs - assert (b"three.com", ()) in ca.certs + assert ("one.com", ()) in tstore.certs + assert ("two.com", ()) in tstore.certs + assert ("three.com", ()) in tstore.certs - ca.get_cert(b"four.com", []) + tstore.get_cert("four.com", []) - assert (b"one.com", ()) not in ca.certs - assert (b"two.com", ()) in ca.certs - assert (b"three.com", ()) in ca.certs - assert (b"four.com", ()) in ca.certs + assert ("one.com", ()) not in tstore.certs + assert ("two.com", ()) in tstore.certs + assert ("three.com", ()) in tstore.certs + assert ("four.com", ()) in tstore.certs - def test_overrides(self, tmpdir): - ca1 = certs.CertStore.from_store(str(tmpdir.join("ca1")), "test", 2048) - ca2 = certs.CertStore.from_store(str(tmpdir.join("ca2")), "test", 2048) - assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number() + def test_overrides(self, tmp_path): + ca1 = certs.CertStore.from_store(tmp_path / "ca1", "test", 2048) + ca2 = certs.CertStore.from_store(tmp_path / "ca2", "test", 2048) + assert not ca1.default_ca.serial == ca2.default_ca.serial - dc = ca2.get_cert(b"foo.com", [b"sans.example.com"]) - dcp = tmpdir.join("dc") - dcp.write(dc[0].to_pem()) - ca1.add_cert_file("foo.com", str(dcp)) + dc = ca2.get_cert("foo.com", ["sans.example.com"]) + dcp = tmp_path / "dc" + dcp.write_bytes(dc.cert.to_pem()) + ca1.add_cert_file("foo.com", dcp) - ret = ca1.get_cert(b"foo.com", []) - assert ret[0].serial == dc[0].serial + ret = ca1.get_cert("foo.com", []) + assert ret.cert.serial == dc.cert.serial - def test_create_dhparams(self, tmpdir): - filename = str(tmpdir.join("dhparam.pem")) + def test_create_dhparams(self, tmp_path): + filename = tmp_path / "dhparam.pem" certs.CertStore.load_dhparam(filename) - assert os.path.exists(filename) + assert filename.exists() @skip_windows def test_umask_secret(self, tmpdir): @@ -123,22 +130,21 @@ class TestCertStore: class TestDummyCert: - def test_with_ca(self, tmpdir): - ca = certs.CertStore.from_store(str(tmpdir), "test", 2048) + def test_with_ca(self, tstore): r = certs.dummy_cert( - ca.default_privatekey, - ca.default_ca, - b"foo.com", - [b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"], - b"Foo Ltd." + tstore.default_privatekey, + tstore.default_ca._cert, + "foo.com", + ["one.com", "two.com", "*.three.com", "127.0.0.1"], + "Foo Ltd." ) - assert r.cn == b"foo.com" - assert r.altnames == [b'one.com', b'two.com', b'*.three.com'] - assert r.organization == b"Foo Ltd." + assert r.cn == "foo.com" + assert r.altnames == ["one.com", "two.com", "*.three.com", "127.0.0.1"] + assert r.organization == "Foo Ltd." r = certs.dummy_cert( - ca.default_privatekey, - ca.default_ca, + tstore.default_privatekey, + tstore.default_ca._cert, None, [], None @@ -154,16 +160,16 @@ class TestCert: with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f: d = f.read() c1 = certs.Cert.from_pem(d) - assert c1.cn == b"google.com" + assert c1.cn == "google.com" assert len(c1.altnames) == 436 - assert c1.organization == b"Google Inc" + assert c1.organization == "Google Inc" with open(tdata.path("mitmproxy/net/data/text_cert_2"), "rb") as f: d = f.read() c2 = certs.Cert.from_pem(d) - assert c2.cn == b"www.inode.co.nz" + assert c2.cn == "www.inode.co.nz" assert len(c2.altnames) == 2 - assert c2.digest("sha1") + assert c2.fingerprint() assert c2.notbefore assert c2.notafter assert c2.subject @@ -171,10 +177,30 @@ class TestCert: assert c2.serial assert c2.issuer assert c2.to_pem() - assert c2.has_expired is not None + assert c2.has_expired() is not None assert c1 != c2 + def test_convert(self, tdata): + with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f: + d = f.read() + c = certs.Cert.from_pem(d) + + assert c == certs.Cert.from_pem(c.to_pem()) + assert c == certs.Cert.from_state(c.get_state()) + assert c == certs.Cert.from_pyopenssl(c.to_pyopenssl()) + + @pytest.mark.parametrize("filename,name,bits", [ + ("text_cert", "RSA", 1024), + ("dsa_cert.pem", "DSA", 1024), + ("ec_cert.pem", "EC (secp256r1)", 256), + ]) + def test_keyinfo(self, tdata, filename, name, bits): + with open(tdata.path(f"mitmproxy/net/data/{filename}"), "rb") as f: + d = f.read() + c = certs.Cert.from_pem(d) + assert c.keyinfo == (name, bits) + def test_err_broken_sans(self, tdata): with open(tdata.path("mitmproxy/net/data/text_cert_weird1"), "rb") as f: d = f.read() @@ -182,12 +208,6 @@ class TestCert: # This breaks unless we ignore a decoding error. assert c.altnames is not None - def test_der(self, tdata): - with open(tdata.path("mitmproxy/net/data/dercert"), "rb") as f: - d = f.read() - s = certs.Cert.from_der(d) - assert s.cn - def test_state(self, tdata): with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f: d = f.read() @@ -201,12 +221,13 @@ class TestCert: assert c == c2 assert c is not c2 - x = certs.Cert('') - x.set_state(a) - assert x == c + c2.set_state(a) + assert c == c2 - def test_from_store_with_passphrase(self, tdata, tmpdir): - ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, b"password") - ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), b"password") + def test_from_store_with_passphrase(self, tdata, tstore): + tstore.add_cert_file("unencrypted-no-pass", Path(tdata.path("mitmproxy/data/testkey.pem")), None) + tstore.add_cert_file("unencrypted-pass", Path(tdata.path("mitmproxy/data/testkey.pem")), b"password") + tstore.add_cert_file("encrypted-pass", Path(tdata.path("mitmproxy/data/mitmproxy.pem")), b"password") - assert ca.get_cert(b"foo", []) + with pytest.raises(TypeError): + tstore.add_cert_file("encrypted-no-pass", Path(tdata.path("mitmproxy/data/mitmproxy.pem")), None) diff --git a/test/mitmproxy/tools/console/test_common.py b/test/mitmproxy/tools/console/test_common.py index 1f59ac4e8..0ae897d2e 100644 --- a/test/mitmproxy/tools/console/test_common.py +++ b/test/mitmproxy/tools/console/test_common.py @@ -25,12 +25,10 @@ def test_format_keyvals(): ("ee", None), ] ) - wrapped = urwid.BoxAdapter( - urwid.ListBox( - urwid.SimpleFocusListWalker( - common.format_keyvals([("foo", "bar")]) - ) - ), 1 + wrapped = urwid.Pile( + urwid.SimpleFocusListWalker( + common.format_keyvals([("foo", "bar")]) + ) ) assert wrapped.render((30,)) assert common.format_keyvals(