Merge pull request #4374 from mhils/cryptography-certs

Use cryptography for certificate generation
This commit is contained in:
Maximilian Hils 2020-12-30 22:57:00 +01:00 committed by GitHub
commit 7d67eefe29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 615 additions and 548 deletions

View File

@ -38,6 +38,7 @@ If you depend on these features, please raise your voice in
### Full Changelog
* New Proxy Core based on sans-io pattern (@mhils)
* Use pyca/cryptography to generate certificates, not pyOpenSSL (@mhils)
* Remove the legacy protocol stack (@Kriechi)
* Remove all deprecated pathod and pathoc tools and modules (@Kriechi)
* --- TODO: add new PRs above this line ---

View File

@ -66,7 +66,7 @@ class NextLayer:
pass
else:
if sni:
hostnames.append(sni.decode("idna"))
hostnames.append(sni)
if not hostnames:
return False

View File

@ -1,7 +1,8 @@
import os
from typing import List, Optional, Tuple, TypedDict, Any
from pathlib import Path
from typing import List, Optional, TypedDict, Any
from OpenSSL import SSL, crypto
from OpenSSL import SSL
from mitmproxy import certs, ctx, exceptions
from mitmproxy.net import tls as net_tls
from mitmproxy.options import CONF_BASENAME
@ -115,7 +116,7 @@ class TlsConfig:
client: context.Client = tls_start.context.client
server: context.Server = tls_start.context.server
cert, key, chain_file = self.get_cert(tls_start.context)
entry = self.get_cert(tls_start.context)
if not client.cipher_list and ctx.options.ciphers_client:
client.cipher_list = ctx.options.ciphers_client.split(":")
@ -126,9 +127,9 @@ class TlsConfig:
min_version=net_tls.Version[ctx.options.tls_version_client_min],
max_version=net_tls.Version[ctx.options.tls_version_client_max],
cipher_list=cipher_list,
cert=cert,
key=key,
chain_file=chain_file,
cert=entry.cert,
key=entry.privatekey,
chain_file=entry.chain_file,
request_client_cert=False,
alpn_select_callback=alpn_select_callback,
extra_chain_certs=server.certificate_list,
@ -152,8 +153,7 @@ class TlsConfig:
verify = net_tls.Verify.VERIFY_PEER
if server.sni is True:
server.sni = client.sni or server.address[0].encode()
sni = server.sni or None # make sure that false-y values are None
server.sni = client.sni or server.address[0]
if not server.alpn_offers:
if client.alpn_offers:
@ -182,7 +182,7 @@ class TlsConfig:
if os.path.isfile(client_certs):
client_cert = client_certs
else:
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
server_name: str = server.sni or server.address[0]
p = os.path.join(client_certs, f"{server_name}.pem")
if os.path.isfile(p):
client_cert = p
@ -192,7 +192,7 @@ class TlsConfig:
max_version=net_tls.Version[ctx.options.tls_version_client_max],
cipher_list=cipher_list,
verify=verify,
sni=sni,
sni=server.sni,
ca_path=ctx.options.ssl_verify_upstream_trusted_confdir,
ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca,
client_cert=client_cert,
@ -200,7 +200,8 @@ class TlsConfig:
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
tls_start.ssl_conn.set_tlsext_host_name(server.sni)
if server.sni:
tls_start.ssl_conn.set_tlsext_host_name(server.sni.encode())
tls_start.ssl_conn.set_connect_state()
def running(self):
@ -232,8 +233,8 @@ class TlsConfig:
if len(parts) == 1:
parts = ["*", parts[0]]
cert = os.path.expanduser(parts[1])
if not os.path.exists(cert):
cert = Path(parts[1]).expanduser()
if not cert.exists():
raise exceptions.OptionsError(f"Certificate file does not exist: {cert}")
try:
self.certstore.add_cert_file(
@ -241,16 +242,16 @@ class TlsConfig:
cert,
passphrase=ctx.options.cert_passphrase.encode("utf8") if ctx.options.cert_passphrase else None,
)
except crypto.Error as e:
raise exceptions.OptionsError(f"Invalid certificate format: {cert}") from e
except ValueError as e:
raise exceptions.OptionsError(f"Invalid certificate format for {cert}: {e}") from e
def get_cert(self, conn_context: context.Context) -> Tuple[certs.Cert, SSL.PKey, str]:
def get_cert(self, conn_context: context.Context) -> certs.CertStoreEntry:
"""
This function determines the Common Name (CN), Subject Alternative Names (SANs) and Organization Name
our certificate should have and then fetches a matching cert from the certstore.
"""
altnames: List[bytes] = []
organization: Optional[bytes] = None
altnames: List[str] = []
organization: Optional[str] = None
# Use upstream certificate if available.
if conn_context.server.certificate_list:
@ -265,11 +266,11 @@ class TlsConfig:
if conn_context.client.sni:
altnames.append(conn_context.client.sni)
elif conn_context.server.address:
altnames.append(conn_context.server.address[0].encode("idna"))
altnames.append(conn_context.server.address[0])
# As a last resort, add *something* so that we have a certificate to serve.
if not altnames:
altnames.append(b"mitmproxy")
altnames.append("mitmproxy")
# only keep first occurrence of each hostname
altnames = list(dict.fromkeys(altnames))

View File

@ -1,23 +1,24 @@
import os
import ssl
import time
import contextlib
import datetime
import ipaddress
import os
import sys
import contextlib
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple, Optional, Union, Dict, List
from typing import Tuple, Optional, Union, Dict, List, NewType
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, dsa, ec
from cryptography.hazmat.primitives.serialization import pkcs12
from cryptography.x509 import NameOID, ExtendedKeyUsageOID
from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error
import OpenSSL
from mitmproxy.coretypes import serializable
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
DEFAULT_EXP = 94608000 # = 60 * 60 * 24 * 365 * 3 = 3 years
DEFAULT_EXP_DUMMY_CERT = 31536000 # = 60 * 60 * 24 * 365 = 1 year
CA_EXPIRY = datetime.timedelta(days=3 * 365)
CERT_EXPIRY = datetime.timedelta(days=365)
# Generated with "openssl dhparam". It's too slow to generate this on startup.
DEFAULT_DHPARAM = b"""
@ -37,51 +38,166 @@ rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
"""
def create_ca(organization: str, cn: str, exp: int, key_size: int) -> Tuple[OpenSSL.crypto.PKey, OpenSSL.crypto.X509]:
key = OpenSSL.crypto.PKey()
key.generate_key(OpenSSL.crypto.TYPE_RSA, key_size)
cert = OpenSSL.crypto.X509()
cert.set_serial_number(int(time.time() * 10000))
cert.set_version(2)
cert.get_subject().CN = cn
cert.get_subject().O = organization
cert.gmtime_adj_notBefore(-3600 * 48)
cert.gmtime_adj_notAfter(exp)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(key)
cert.add_extensions([
OpenSSL.crypto.X509Extension(
b"basicConstraints",
True,
b"CA:TRUE"
),
OpenSSL.crypto.X509Extension(
b"nsCertType",
False,
b"sslCA"
),
OpenSSL.crypto.X509Extension(
b"extendedKeyUsage",
False,
b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC"
),
OpenSSL.crypto.X509Extension(
b"keyUsage",
True,
b"keyCertSign, cRLSign"
),
OpenSSL.crypto.X509Extension(
b"subjectKeyIdentifier",
False,
b"hash",
subject=cert
),
class Cert(serializable.Serializable):
_cert: x509.Certificate
def __init__(self, cert: x509.Certificate):
assert isinstance(cert, x509.Certificate)
self._cert = cert
def __eq__(self, other):
return self.fingerprint() == other.fingerprint()
@classmethod
def from_state(cls, state):
return cls.from_pem(state)
def get_state(self):
return self.to_pem()
def set_state(self, state):
self._cert = x509.load_pem_x509_certificate(state)
@classmethod
def from_pem(cls, data: bytes) -> "Cert":
cert = x509.load_pem_x509_certificate(data) # type: ignore
return cls(cert)
def to_pem(self) -> bytes:
return self._cert.public_bytes(serialization.Encoding.PEM)
@classmethod
def from_pyopenssl(self, x509: OpenSSL.crypto.X509) -> "Cert":
return Cert(x509.to_cryptography())
def to_pyopenssl(self) -> OpenSSL.crypto.X509:
return OpenSSL.crypto.X509.from_cryptography(self._cert)
def fingerprint(self) -> bytes:
return self._cert.fingerprint(hashes.SHA256())
@property
def issuer(self) -> List[Tuple[str, str]]:
return _name_to_keyval(self._cert.issuer)
@property
def notbefore(self) -> datetime.datetime:
return self._cert.not_valid_before
@property
def notafter(self) -> datetime.datetime:
return self._cert.not_valid_after
def has_expired(self) -> bool:
return datetime.datetime.utcnow() > self._cert.not_valid_after
@property
def subject(self) -> List[Tuple[str, str]]:
return _name_to_keyval(self._cert.subject)
@property
def serial(self) -> int:
return self._cert.serial_number
@property
def keyinfo(self):
public_key = self._cert.public_key()
if isinstance(public_key, rsa.RSAPublicKey):
return "RSA", public_key.key_size
if isinstance(public_key, dsa.DSAPublicKey):
return "DSA", public_key.key_size
if isinstance(public_key, ec.EllipticCurvePublicKey):
return f"EC ({public_key.curve.name})", public_key.key_size
return (public_key.__class__.__name__.replace("PublicKey", "").replace("_", ""),
getattr(public_key, "key_size", -1)) # pragma: no cover
@property
def cn(self) -> Optional[str]:
attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
if attrs:
return attrs[0].value
return None
@property
def organization(self) -> Optional[str]:
attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.ORGANIZATION_NAME)
if attrs:
return attrs[0].value
return None
@property
def altnames(self) -> List[str]:
"""
Get all SubjectAlternativeName DNS altnames.
"""
try:
ext = self._cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value
except x509.ExtensionNotFound:
return []
else:
return (
ext.get_values_for_type(x509.DNSName)
+
[str(x) for x in ext.get_values_for_type(x509.IPAddress)]
)
def _name_to_keyval(name: x509.Name) -> List[Tuple[str, str]]:
parts = []
for rdn in name.rdns:
k, v = rdn.rfc4514_string().split("=", maxsplit=1)
parts.append((k, v))
return parts
def create_ca(
organization: str,
cn: str,
key_size: int,
) -> Tuple[rsa.RSAPrivateKeyWithSerialization, x509.Certificate]:
now = datetime.datetime.now()
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=key_size,
) # type: ignore
name = x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, cn),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization)
])
cert.sign(key, "sha256")
return key, cert
builder = x509.CertificateBuilder()
builder = builder.serial_number(x509.random_serial_number())
builder = builder.subject_name(name)
builder = builder.not_valid_before(now - datetime.timedelta(days=2))
builder = builder.not_valid_after(now + CA_EXPIRY)
builder = builder.issuer_name(name)
builder = builder.public_key(private_key.public_key())
builder = builder.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
builder = builder.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False)
builder = builder.add_extension(
x509.KeyUsage(
digital_signature=False,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
key_cert_sign=True,
crl_sign=True,
encipher_only=False,
decipher_only=False,
), critical=True)
builder = builder.add_extension(x509.SubjectKeyIdentifier.from_public_key(private_key.public_key()), critical=False)
cert = builder.sign(private_key=private_key, algorithm=hashes.SHA256()) # type: ignore
return private_key, cert
def dummy_cert(privkey, cacert, commonname, sans, organization):
def dummy_cert(
privkey: rsa.RSAPrivateKey,
cacert: x509.Certificate,
commonname: Optional[str],
sans: List[str],
organization: Optional[str] = None,
) -> Cert:
"""
Generates a dummy certificate.
@ -93,111 +209,114 @@ def dummy_cert(privkey, cacert, commonname, sans, organization):
Returns cert if operation succeeded, None if not.
"""
ss = []
for i in sans:
try:
ipaddress.ip_address(i.decode("ascii"))
except ValueError:
ss.append(b"DNS:%s" % i)
else:
ss.append(b"IP:%s" % i)
ss = b", ".join(ss)
builder = x509.CertificateBuilder()
builder = builder.issuer_name(cacert.subject)
builder = builder.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False)
builder = builder.public_key(cacert.public_key())
cert = OpenSSL.crypto.X509()
cert.gmtime_adj_notBefore(-3600 * 48)
cert.gmtime_adj_notAfter(DEFAULT_EXP_DUMMY_CERT)
cert.set_issuer(cacert.get_subject())
now = datetime.datetime.now()
builder = builder.not_valid_before(now - datetime.timedelta(days=2))
builder = builder.not_valid_after(now + CERT_EXPIRY)
subject = []
is_valid_commonname = (
commonname is not None and len(commonname) < 64
)
if is_valid_commonname:
cert.get_subject().CN = commonname
assert commonname is not None
subject.append(x509.NameAttribute(NameOID.COMMON_NAME, commonname))
if organization is not None:
cert.get_subject().O = organization
cert.set_serial_number(int(time.time() * 10000))
if ss:
cert.set_version(2)
cert.add_extensions(
[OpenSSL.crypto.X509Extension(
b"subjectAltName",
assert organization is not None
subject.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization))
builder = builder.subject_name(x509.Name(subject))
builder = builder.serial_number(x509.random_serial_number())
ss: List[x509.GeneralName] = []
for x in sans:
try:
ip = ipaddress.ip_address(x)
except ValueError:
ss.append(x509.DNSName(x))
else:
ss.append(x509.IPAddress(ip))
# RFC 5280 §4.2.1.6: subjectAltName is critical if subject is empty.
not is_valid_commonname,
ss
)]
)
cert.add_extensions([
OpenSSL.crypto.X509Extension(
b"extendedKeyUsage",
False,
b"serverAuth,clientAuth"
)
])
cert.set_pubkey(cacert.get_pubkey())
cert.sign(privkey, "sha256")
builder = builder.add_extension(x509.SubjectAlternativeName(ss), critical=not is_valid_commonname)
cert = builder.sign(private_key=privkey, algorithm=hashes.SHA256()) # type: ignore
return Cert(cert)
@dataclass(frozen=True)
class CertStoreEntry:
def __init__(self, cert, privatekey, chain_file):
self.cert = cert
self.privatekey = privatekey
self.chain_file = chain_file
cert: Cert
privatekey: rsa.RSAPrivateKey
chain_file: Optional[Path]
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = Tuple[Optional[bytes], Tuple[bytes, ...]] # (common_name, sans)
TCustomCertId = str # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = Tuple[Optional[str], Tuple[str, ...]] # (common_name, sans)
TCertId = Union[TCustomCertId, TGeneratedCertId]
DHParams = NewType("DHParams", bytes)
class CertStore:
"""
Implements an in-memory certificate store.
"""
STORE_CAP = 100
certs: Dict[TCertId, CertStoreEntry]
expire_queue: List[CertStoreEntry]
def __init__(
self,
default_privatekey,
default_ca,
default_chain_file,
dhparams):
default_privatekey: rsa.RSAPrivateKey,
default_ca: Cert,
default_chain_file: Optional[Path],
dhparams: DHParams
):
self.default_privatekey = default_privatekey
self.default_ca = default_ca
self.default_chain_file = default_chain_file
self.dhparams = dhparams
self.certs: Dict[TCertId, CertStoreEntry] = {}
self.certs = {}
self.expire_queue = []
def expire(self, entry):
def expire(self, entry: CertStoreEntry) -> None:
self.expire_queue.append(entry)
if len(self.expire_queue) > self.STORE_CAP:
d = self.expire_queue.pop(0)
self.certs = {k: v for k, v in self.certs.items() if v != d}
@staticmethod
def load_dhparam(path: str):
def load_dhparam(path: Path) -> DHParams:
# mitmproxy<=0.10 doesn't generate a dhparam file.
# Create it now if necessary.
if not os.path.exists(path):
with open(path, "wb") as f:
f.write(DEFAULT_DHPARAM)
if not path.exists():
path.write_bytes(DEFAULT_DHPARAM)
bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r")
# we could use cryptography for this, but it's unclear how to convert cryptography's object to pyOpenSSL's
# expected format.
bio = OpenSSL.SSL._lib.BIO_new_file(str(path).encode(sys.getfilesystemencoding()), b"r")
if bio != OpenSSL.SSL._ffi.NULL:
bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)
dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams(
bio,
OpenSSL.SSL._ffi.NULL,
OpenSSL.SSL._ffi.NULL,
OpenSSL.SSL._ffi.NULL)
OpenSSL.SSL._ffi.NULL
)
dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)
return dh
raise RuntimeError("Error loading DH Params.") # pragma: no cover
@classmethod
def from_store(cls, path: Union[Path, str], basename: str, key_size, passphrase: Optional[bytes] = None) -> "CertStore":
def from_store(
cls,
path: Union[Path, str],
basename: str,
key_size: int,
passphrase: Optional[bytes] = None
) -> "CertStore":
path = Path(path)
ca_file = path / f"{basename}-ca.pem"
dhparam_file = path / f"{basename}-dhparam.pem"
@ -208,17 +327,14 @@ class CertStore:
@classmethod
def from_files(cls, ca_file: Path, dhparam_file: Path, passphrase: Optional[bytes] = None) -> "CertStore":
raw = ca_file.read_bytes()
ca = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_PEM,
raw
)
key = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
raw,
passphrase
)
dh = cls.load_dhparam(str(dhparam_file))
return cls(key, ca, str(ca_file), dh)
key = load_pem_private_key(raw, passphrase)
ca = Cert.from_pem(raw)
dh = cls.load_dhparam(dhparam_file)
if raw.count(b"BEGIN CERTIFICATE") != 1:
chain_file: Optional[Path] = ca_file
else:
chain_file = None
return cls(key, ca, chain_file, dh)
@staticmethod
@contextlib.contextmanager
@ -236,74 +352,72 @@ class CertStore:
os.umask(original_umask)
@staticmethod
def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None, expiry=DEFAULT_EXP) -> None:
def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None) -> None:
path.mkdir(parents=True, exist_ok=True)
organization = organization or basename
cn = cn or basename
key, ca = create_ca(organization=organization, cn=cn, exp=expiry, key_size=key_size)
# Dump the CA plus private key
with CertStore.umask_secret(), (path / f"{basename}-ca.pem").open("wb") as f:
f.write(
OpenSSL.crypto.dump_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
key))
f.write(
OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM,
ca))
key: rsa.RSAPrivateKeyWithSerialization
ca: x509.Certificate
key, ca = create_ca(organization=organization, cn=cn, key_size=key_size)
# Dump the certificate in PEM format
with (path / f"{basename}-ca-cert.pem").open("wb") as f:
f.write(
OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM,
ca))
# Create a .cer file with the same contents for Android
with (path / f"{basename}-ca-cert.cer").open("wb") as f:
f.write(
OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM,
ca))
# Dump the certificate in PKCS12 format for Windows devices
with (path / f"{basename}-ca-cert.p12").open("wb") as f:
p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca)
f.write(p12.export())
# Dump the certificate and key in a PKCS12 format for Windows devices
with CertStore.umask_secret(), (path / f"{basename}-ca.p12").open("wb") as f:
p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca)
p12.set_privatekey(key)
f.write(p12.export())
with (path / f"{basename}-dhparam.pem").open("wb") as f:
f.write(DEFAULT_DHPARAM)
def add_cert_file(self, spec: str, path: str, passphrase: Optional[bytes] = None) -> None:
with open(path, "rb") as f:
raw = f.read()
cert = Cert(
OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_PEM,
raw))
try:
privatekey = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
raw,
passphrase)
except Exception:
privatekey = self.default_privatekey
self.add_cert(
CertStoreEntry(cert, privatekey, path),
spec.encode("idna")
# Dump the CA plus private key.
with CertStore.umask_secret():
# PEM format
(path / f"{basename}-ca.pem").write_bytes(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
) +
ca.public_bytes(serialization.Encoding.PEM)
)
def add_cert(self, entry: CertStoreEntry, *names: bytes):
# PKCS12 format for Windows devices
(path / f"{basename}-ca.p12").write_bytes(
pkcs12.serialize_key_and_certificates( # type: ignore
name=basename.encode(),
key=key,
cert=ca,
cas=None,
encryption_algorithm=serialization.NoEncryption(),
)
)
# Dump the certificate in PEM format
pem_cert = ca.public_bytes(serialization.Encoding.PEM)
(path / f"{basename}-ca-cert.pem").write_bytes(pem_cert)
# Create a .cer file with the same contents for Android
(path / f"{basename}-ca-cert.cer").write_bytes(pem_cert)
# Dump the certificate in PKCS12 format for Windows devices
(path / f"{basename}-ca-cert.p12").write_bytes(
pkcs12.serialize_key_and_certificates( # type: ignore
name=basename.encode(),
key=None,
cert=ca,
cas=None,
encryption_algorithm=serialization.NoEncryption(),
)
)
(path / f"{basename}-dhparam.pem").write_bytes(DEFAULT_DHPARAM)
def add_cert_file(self, spec: str, path: Path, passphrase: Optional[bytes] = None) -> None:
raw = path.read_bytes()
cert = Cert.from_pem(raw)
try:
key = load_pem_private_key(raw, password=passphrase)
except ValueError:
key = self.default_privatekey
self.add_cert(
CertStoreEntry(cert, key, path),
spec
)
def add_cert(self, entry: CertStoreEntry, *names: str) -> None:
"""
Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument.
@ -316,26 +430,24 @@ class CertStore:
self.certs[i] = entry
@staticmethod
def asterisk_forms(dn: bytes) -> List[bytes]:
def asterisk_forms(dn: str) -> List[str]:
"""
Return all asterisk forms for a domain. For example, for www.example.com this will return
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
"""
parts = dn.split(b".")
parts = dn.split(".")
ret = [dn]
for i in range(1, len(parts)):
ret.append(b"*." + b".".join(parts[i:]))
ret.append("*." + ".".join(parts[i:]))
return ret
def get_cert(
self,
commonname: Optional[bytes],
sans: List[bytes],
organization: Optional[bytes] = None
) -> Tuple["Cert", OpenSSL.SSL.PKey, str]:
commonname: Optional[str],
sans: List[str],
organization: Optional[str] = None
) -> CertStoreEntry:
"""
Returns an (cert, privkey, cert_chain) tuple.
commonname: Common name for the generated certificate. Must be a
valid, plain-ASCII, IDNA-encoded domain name.
@ -349,7 +461,7 @@ class CertStore:
potential_keys.extend(self.asterisk_forms(commonname))
for s in sans:
potential_keys.extend(self.asterisk_forms(s))
potential_keys.append(b"*")
potential_keys.append("*")
potential_keys.append((commonname, tuple(sans)))
name = next(
@ -362,147 +474,27 @@ class CertStore:
entry = CertStoreEntry(
cert=dummy_cert(
self.default_privatekey,
self.default_ca,
self.default_ca._cert,
commonname,
sans,
organization),
privatekey=self.default_privatekey,
chain_file=self.default_chain_file)
chain_file=self.default_chain_file
)
self.certs[(commonname, tuple(sans))] = entry
self.expire(entry)
return entry.cert, entry.privatekey, entry.chain_file
return entry
class _GeneralName(univ.Choice):
# We only care about dNSName and iPAddress
componentType = namedtype.NamedTypes(
namedtype.NamedType('dNSName', char.IA5String().subtype(
implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2)
)),
namedtype.NamedType('iPAddress', univ.OctetString().subtype(
implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 7)
)),
)
class _GeneralNames(univ.SequenceOf):
componentType = _GeneralName()
sizeSpec = univ.SequenceOf.sizeSpec + \
constraint.ValueSizeConstraint(1, 1024)
class Cert(serializable.Serializable):
def __init__(self, cert):
def load_pem_private_key(data: bytes, password: Optional[bytes]) -> rsa.RSAPrivateKey:
"""
Returns a (common name, [subject alternative names]) tuple.
like cryptography's load_pem_private_key, but silently falls back to not using a password
if the private key is unencrypted.
"""
self.x509 = cert
def __eq__(self, other):
return self.digest("sha256") == other.digest("sha256")
def get_state(self):
return self.to_pem()
def set_state(self, state):
self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state)
@classmethod
def from_state(cls, state):
return cls.from_pem(state)
@classmethod
def from_pem(cls, txt):
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
return cls(x509)
@classmethod
def from_der(cls, der):
pem = ssl.DER_cert_to_PEM_cert(der)
return cls.from_pem(pem)
def to_pem(self):
return OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM,
self.x509)
def digest(self, name):
return self.x509.digest(name)
@property
def issuer(self):
return self.x509.get_issuer().get_components()
@property
def notbefore(self):
t = self.x509.get_notBefore()
return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ")
@property
def notafter(self):
t = self.x509.get_notAfter()
return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ")
@property
def has_expired(self):
return self.x509.has_expired()
@property
def subject(self):
return self.x509.get_subject().get_components()
@property
def serial(self):
return self.x509.get_serial_number()
@property
def keyinfo(self):
pk = self.x509.get_pubkey()
types = {
OpenSSL.crypto.TYPE_RSA: "RSA",
OpenSSL.crypto.TYPE_DSA: "DSA",
}
return (
types.get(pk.type(), "UNKNOWN"),
pk.bits()
)
@property
def cn(self) -> Optional[bytes]:
c = None
for i in self.subject:
if i[0] == b"CN":
c = i[1]
return c
@property
def organization(self) -> Optional[bytes]:
c = None
for i in self.subject:
if i[0] == b"O":
c = i[1]
return c
@property
def altnames(self) -> List[bytes]:
"""
Returns:
All DNS altnames.
"""
# tcp.TCPClient.convert_to_tls assumes that this property only contains DNS altnames for hostname verification.
altnames = []
for i in range(self.x509.get_extension_count()):
ext = self.x509.get_extension(i)
if ext.get_short_name() == b"subjectAltName":
try:
dec = decode(ext.get_data(), asn1Spec=_GeneralNames())
except PyAsn1Error:
continue
for x in dec[0]:
if x[0].hasValue():
e = x[0].asOctets()
altnames.append(e)
return altnames
return serialization.load_pem_private_key(data, password) # type: ignore
except TypeError:
if password is not None:
return load_pem_private_key(data, None)
raise

View File

@ -230,6 +230,23 @@ def convert_9_10(data):
return data
def convert_10_11(data):
data["version"] = 11
def conv_conn(conn):
conn["sni"] = strutils.always_str(conn["sni"], "ascii", "backslashreplace")
conn["alpn"] = conn.pop("alpn_proto_negotiated")
conn["alpn_offers"] = conn["alpn_offers"] or []
conn["cipher_list"] = conn["cipher_list"] or []
conv_conn(data["client_conn"])
conv_conn(data["server_conn"])
if data["server_conn"]["via"]:
conv_conn(data["server_conn"]["via"])
return data
def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
@ -287,6 +304,7 @@ converters = {
7: convert_7_8,
8: convert_8_9,
9: convert_9_10,
10: convert_10_11,
}

View File

@ -6,9 +6,10 @@ from pathlib import Path
from typing import Iterable, Callable, Optional, Tuple, List, Any, BinaryIO
import certifi
from cryptography.hazmat.primitives.asymmetric import rsa
from kaitaistruct import KaitaiStream
from OpenSSL import SSL
from OpenSSL import SSL, crypto
from mitmproxy import certs
from mitmproxy.contrib.kaitaistruct import tls_client_hello
from mitmproxy.net import check
@ -129,7 +130,7 @@ def create_proxy_server_context(
max_version: Version,
cipher_list: Optional[Iterable[str]],
verify: Verify,
sni: Optional[bytes],
sni: Optional[str],
ca_path: Optional[str],
ca_pemfile: Optional[str],
client_cert: Optional[str],
@ -147,6 +148,7 @@ def create_proxy_server_context(
context.set_verify(verify.value, None)
if sni is not None:
assert isinstance(sni, str)
# Manually enable hostname verification on the context object.
# https://wiki.openssl.org/index.php/Hostname_validation
param = SSL._lib.SSL_CTX_get0_param(context._context)
@ -157,7 +159,7 @@ def create_proxy_server_context(
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
)
SSL._openssl_assert(
SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni, 0) == 1
SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni.encode(), 0) == 1
)
if ca_path is None and ca_pemfile is None:
@ -188,12 +190,12 @@ def create_client_proxy_context(
max_version: Version,
cipher_list: Optional[Iterable[str]],
cert: certs.Cert,
key: SSL.PKey,
chain_file: str,
key: rsa.RSAPrivateKey,
chain_file: Optional[Path],
alpn_select_callback: Optional[Callable[[SSL.Connection, List[bytes]], Any]],
request_client_cert: bool,
extra_chain_certs: Iterable[certs.Cert],
dhparams,
dhparams: certs.DHParams,
) -> SSL.Context:
context: SSL.Context = _create_ssl_context(
method=Method.TLS_SERVER_METHOD,
@ -202,10 +204,11 @@ def create_client_proxy_context(
cipher_list=cipher_list,
)
context.use_certificate(cert.x509)
context.use_privatekey(key)
context.use_certificate(cert.to_pyopenssl())
context.use_privatekey(crypto.PKey.from_cryptography_key(key))
if chain_file is not None:
try:
context.load_verify_locations(chain_file, None)
context.load_verify_locations(str(chain_file), None)
except SSL.Error as e:
raise RuntimeError(f"Cannot load certificate chain ({chain_file}).") from e
@ -227,7 +230,7 @@ def create_client_proxy_context(
context.set_verify(Verify.VERIFY_NONE.value, None)
for i in extra_chain_certs:
context.add_extra_chain_cert(i.x509)
context.add_extra_chain_cert(i._cert)
if dhparams:
SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)
@ -277,7 +280,7 @@ class ClientHello:
return self._client_hello.cipher_suites.cipher_suites
@property
def sni(self) -> Optional[bytes]:
def sni(self) -> Optional[str]:
if self._client_hello.extensions:
for extension in self._client_hello.extensions.extensions:
is_valid_sni_extension = (
@ -287,11 +290,11 @@ class ClientHello:
check.is_valid_host(extension.body.server_names[0].host_name)
)
if is_valid_sni_extension:
return extension.body.server_names[0].host_name
return extension.body.server_names[0].host_name.decode("ascii")
return None
@property
def alpn_protocols(self):
def alpn_protocols(self) -> List[bytes]:
if self._client_hello.extensions:
for extension in self._client_hello.extensions.extensions:
if extension.type == 0x10:

View File

@ -65,7 +65,7 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
cipher_list: Sequence[str] = ()
"""Ciphers accepted by the proxy server on this connection."""
tls_version: Optional[str] = None
sni: Union[bytes, Literal[True], None]
sni: Union[str, Literal[True], None]
timestamp_end: Optional[float] = None
"""Connection end timestamp"""
@ -109,7 +109,8 @@ class Client(Connection):
timestamp_start: float
"""TCP SYN received"""
sni: Union[bytes, None] = None
mitmcert: Optional[certs.Cert] = None
sni: Union[str, None] = None
def __init__(self, peername, sockname, timestamp_start):
self.id = str(uuid.uuid4())
@ -122,10 +123,10 @@ class Client(Connection):
# This means we need to add all new fields to the old implementation.
return {
'address': self.peername,
'alpn_proto_negotiated': self.alpn,
'alpn': self.alpn,
'cipher_name': self.cipher,
'id': self.id,
'mitmcert': None,
'mitmcert': self.mitmcert.get_state() if self.mitmcert is not None else None,
'sni': self.sni,
'timestamp_end': self.timestamp_end,
'timestamp_start': self.timestamp_start,
@ -155,7 +156,7 @@ class Client(Connection):
def set_state(self, state):
self.peername = tuple(state["address"]) if state["address"] else None
self.alpn = state["alpn_proto_negotiated"]
self.alpn = state["alpn"]
self.cipher = state["cipher_name"]
self.id = state["id"]
self.sni = state["sni"]
@ -169,6 +170,7 @@ class Client(Connection):
self.error = state["error"]
self.tls = state["tls"]
self.certificate_list = [certs.Cert.from_state(x) for x in state["certificate_list"]]
self.mitmcert = certs.Cert.from_state(state["mitmcert"]) if state["mitmcert"] is not None else None
self.alpn_offers = state["alpn_offers"]
self.cipher_list = state["cipher_list"]
@ -215,7 +217,7 @@ class Server(Connection):
timestamp_tcp_setup: Optional[float] = None
"""TCP ACK received"""
sni: Union[bytes, Literal[True], None] = True
sni: Union[str, Literal[True], None] = True
"""True: client SNI, False: no SNI, bytes: custom value"""
via: Optional[server_spec.ServerSpec] = None
@ -226,7 +228,7 @@ class Server(Connection):
def get_state(self):
return {
'address': self.address,
'alpn_proto_negotiated': self.alpn,
'alpn': self.alpn,
'id': self.id,
'ip_address': self.peername,
'sni': self.sni,
@ -257,7 +259,7 @@ class Server(Connection):
def set_state(self, state):
self.address = tuple(state["address"]) if state["address"] else None
self.alpn = state["alpn_proto_negotiated"]
self.alpn = state["alpn"]
self.id = state["id"]
self.peername = tuple(state["ip_address"]) if state["ip_address"] else None
self.sni = state["sni"]

View File

@ -435,7 +435,7 @@ class HttpStream(layer.Layer):
stack = tunnel.LayerStack()
if self.context.server.via.scheme == "https":
http_proxy.sni = self.context.server.via.address[0].encode()
http_proxy.sni = self.context.server.via.address[0]
stack /= tls.ServerTLSLayer(self.context, http_proxy)
stack /= _upstream_proxy.HttpUpstreamProxy(self.context, http_proxy, True)
@ -635,7 +635,7 @@ class HttpLayer(layer.Layer):
context.server = Server(event.address)
if event.tls:
context.server.sni = event.address[0].encode()
context.server.sni = event.address[0]
if event.via:
assert event.via.scheme in ("http", "https")
@ -643,7 +643,7 @@ class HttpLayer(layer.Layer):
if event.via.scheme == "https":
http_proxy.alpn_offers = tls.HTTP_ALPNS
http_proxy.sni = event.via.address[0].encode()
http_proxy.sni = event.via.address[0]
stack /= tls.ServerTLSLayer(context, http_proxy)
send_connect = not (self.mode == HTTPMode.upstream and not event.tls)

View File

@ -42,7 +42,7 @@ class ReverseProxy(DestinationKnown):
if spec.scheme not in ("http", "tcp"):
if not self.context.options.keep_host_header:
self.context.server.sni = spec.address[0].encode()
self.context.server.sni = spec.address[0]
self.child_layer = tls.ServerTLSLayer(self.context)
else:
self.child_layer = layer.NextLayer(self.context)

View File

@ -4,7 +4,6 @@ from dataclasses import dataclass
from typing import Iterator, Optional, Tuple
from OpenSSL import SSL
from mitmproxy import certs
from mitmproxy.net import tls as net_tls
from mitmproxy.proxy import commands, events, layer, tunnel
@ -184,6 +183,8 @@ class _TLSLayer(tunnel.TunnelLayer):
err = f"OpenSSL {e!r}"
return False, err
else:
# Here we set all attributes that are only known *after* the handshake.
# Get all peer certificates.
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
# If called on the client side, the stack also contains the peer's certificate; if called on the server
@ -195,9 +196,8 @@ class _TLSLayer(tunnel.TunnelLayer):
all_certs.insert(0, cert)
self.conn.timestamp_tls_setup = time.time()
self.conn.sni = self.tls.get_servername()
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
self.conn.certificate_list = [certs.Cert(x) for x in all_certs]
self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs]
self.conn.cipher = self.tls.get_cipher_name()
self.conn.tls_version = self.tls.get_protocol_version_name()
if self.debug:
@ -338,8 +338,7 @@ class ClientTLSLayer(_TLSLayer):
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
if self.conn.sni:
assert isinstance(self.conn.sni, bytes)
dest = self.conn.sni.decode("idna")
dest = self.conn.sni
else:
dest = human.format_address(self.context.server.address)
if err.startswith("Cannot parse ClientHello"):

View File

@ -427,7 +427,8 @@ if __name__ == "__main__": # pragma: no cover
tls_start.ssl_conn.set_accept_state()
else:
tls_start.ssl_conn.set_connect_state()
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni)
if tls_start.context.client.sni is not None:
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni.encode())
await SimpleConnectionHandler(reader, writer, opts, {
"next_layer": next_layer,

View File

@ -158,7 +158,7 @@ def tclient_conn() -> context.Client:
timestamp_end=946681206,
sni="address",
cipher_name="cipher",
alpn_proto_negotiated=b"http/1.1",
alpn=b"http/1.1",
tls_version="TLSv1.2",
tls_extensions=[(0x00, bytes.fromhex("000e00000b6578616d"))],
state=0,
@ -185,7 +185,7 @@ def tserver_conn() -> context.Server:
timestamp_end=946681205,
tls_established=True,
sni="address",
alpn_proto_negotiated=None,
alpn=None,
tls_version="TLSv1.2",
via=None,
state=0,

View File

@ -4,8 +4,7 @@ import urwid
import mitmproxy.flow
from mitmproxy import http
from mitmproxy.tools.console import common, searchable
from mitmproxy.utils import human
from mitmproxy.utils import strutils
from mitmproxy.utils import human, strutils
def maybe_timestamp(base, attr):
@ -40,64 +39,37 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
text.append(urwid.Text([("head", "Metadata:")]))
text.extend(common.format_keyvals(parts, indent=4))
if sc is not None and sc.ip_address:
if sc is not None and sc.peername:
text.append(urwid.Text([("head", "Server Connection:")]))
parts = [
("Address", human.format_address(sc.address)),
]
if sc.ip_address:
parts.append(("Resolved Address", human.format_address(sc.ip_address)))
if sc.peername:
parts.append(("Resolved Address", human.format_address(sc.peername)))
if resp:
parts.append(("HTTP Version", resp.http_version))
if sc.alpn_proto_negotiated:
parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn_proto_negotiated)))
if sc.alpn:
parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn)))
text.extend(
common.format_keyvals(parts, indent=4)
)
c = sc.cert
c = sc.certificate_list[0]
if c:
text.append(urwid.Text([("head", "Server Certificate:")]))
parts = [
("Type", "%s, %s bits" % c.keyinfo),
("SHA1 digest", c.digest("sha1")),
("SHA256 digest", c.fingerprint().hex()),
("Valid to", str(c.notafter)),
("Valid from", str(c.notbefore)),
("Serial", str(c.serial)),
(
"Subject",
urwid.BoxAdapter(
urwid.ListBox(
common.format_keyvals(
c.subject,
key_format="highlight"
)
),
len(c.subject)
)
),
(
"Issuer",
urwid.BoxAdapter(
urwid.ListBox(
common.format_keyvals(
c.issuer,
key_format="highlight"
)
),
len(c.issuer)
)
)
("Subject", urwid.Pile(common.format_keyvals(c.subject, key_format="highlight"))),
("Issuer", urwid.Pile(common.format_keyvals(c.issuer, key_format="highlight")))
]
if c.altnames:
parts.append(
(
"Alt names",
", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames)
)
)
parts.append(("Alt names", ", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames)))
text.extend(
common.format_keyvals(parts, indent=4)
)
@ -106,19 +78,18 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
text.append(urwid.Text([("head", "Client Connection:")]))
parts = [
("Address", "{}:{}".format(cc.address[0], cc.address[1])),
("Address", human.format_address(cc.peername)),
]
if req:
parts.append(("HTTP Version", req.http_version))
if cc.tls_version:
parts.append(("TLS Version", cc.tls_version))
if cc.sni:
parts.append(("Server Name Indication",
strutils.bytes_to_escaped_str(strutils.always_bytes(cc.sni, "idna"))))
if cc.cipher_name:
parts.append(("Cipher Name", cc.cipher_name))
if cc.alpn_proto_negotiated:
parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn_proto_negotiated)))
parts.append(("Server Name Indication", cc.sni))
if cc.cipher:
parts.append(("Cipher Name", cc.cipher))
if cc.alpn:
parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn)))
text.extend(
common.format_keyvals(parts, indent=4)

View File

@ -47,8 +47,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"timestamp_start": flow.client_conn.timestamp_start,
"timestamp_tls_setup": flow.client_conn.timestamp_tls_setup,
"timestamp_end": flow.client_conn.timestamp_end,
# ideally idna, but we don't want errors
"sni": always_str(flow.client_conn.sni, "ascii", "backslashreplace"),
"sni": flow.client_conn.sni,
"cipher_name": flow.client_conn.cipher,
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
"tls_version": flow.client_conn.tls_version,
@ -61,18 +60,14 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"ip_address": flow.server_conn.peername,
"source_address": flow.server_conn.sockname,
"tls_established": flow.server_conn.tls_established,
"alpn_proto_negotiated": always_str(flow.server_conn.alpn, "ascii", "backslashreplace"),
"sni": flow.server_conn.sni,
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
"tls_version": flow.server_conn.tls_version,
"timestamp_start": flow.server_conn.timestamp_start,
"timestamp_tcp_setup": flow.server_conn.timestamp_tcp_setup,
"timestamp_tls_setup": flow.server_conn.timestamp_tls_setup,
"timestamp_end": flow.server_conn.timestamp_end,
}
if flow.server_conn.sni is True:
f["server_conn"] = None
else:
# ideally idna, but we don't want errors
f["server_conn"] = always_str(flow.server_conn.sni, "ascii", "backslashreplace")
if flow.error:
f["error"] = flow.error.get_state()

View File

@ -7,7 +7,7 @@ MITMPROXY = "mitmproxy " + VERSION
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
# for each change in the file format.
FLOW_FORMAT_VERSION = 10
FLOW_FORMAT_VERSION = 11
def get_dev_version() -> str:

View File

@ -80,7 +80,6 @@ setup(
"msgpack>=1.0.0, <1.1.0",
"passlib>=1.6.5, <1.8",
"protobuf>=3.14,<3.15",
"pyasn1>=0.3.1,<0.5",
"pyOpenSSL>=20.0,<20.1",
"pyparsing>=2.4.2,<2.5",
"pyperclip>=1.6.0,<1.9",

View File

@ -0,0 +1,21 @@
import secrets
from pathlib import Path
import objgraph
from mitmproxy import certs
if __name__ == "__main__":
store = certs.CertStore.from_store(path=Path("~/.mitmproxy/").expanduser(), basename="mitmproxy", key_size=2048)
store.STORE_CAP = 5
for _ in range(5):
store.get_cert(commonname=secrets.token_hex(16).encode(), sans=[], organization=None)
objgraph.show_growth()
for _ in range(20):
store.get_cert(commonname=secrets.token_hex(16).encode(), sans=[], organization=None)
print("====")
objgraph.show_growth()

View File

@ -57,21 +57,21 @@ class TestTlsConfig:
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
# Edge case first: We don't have _any_ idea about the server, so we just return "mitmproxy" as subject.
cert, pkey, chainfile = ta.get_cert(ctx)
assert cert.cn == b"mitmproxy"
entry = ta.get_cert(ctx)
assert entry.cert.cn == "mitmproxy"
# Here we have an existing server connection...
ctx.server.address = ("server-address.example", 443)
with open(tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.crt"), "rb") as f:
ctx.server.certificate_list = [certs.Cert.from_pem(f.read())]
cert, pkey, chainfile = ta.get_cert(ctx)
assert cert.cn == b"example.mitmproxy.org"
assert cert.altnames == [b"example.mitmproxy.org", b"server-address.example"]
entry = ta.get_cert(ctx)
assert entry.cert.cn == "example.mitmproxy.org"
assert entry.cert.altnames == ["example.mitmproxy.org", "server-address.example"]
# And now we also incorporate SNI.
ctx.client.sni = b"sni.example"
cert, pkey, chainfile = ta.get_cert(ctx)
assert cert.altnames == [b"example.mitmproxy.org", b"sni.example"]
ctx.client.sni = "sni.example"
entry = ta.get_cert(ctx)
assert entry.cert.altnames == ["example.mitmproxy.org", "sni.example"]
def test_tls_clienthello(self):
# only really testing for coverage here, there's no point in mirroring the individual conditions
@ -222,7 +222,7 @@ class TestTlsConfig:
@pytest.mark.asyncio
async def test_ca_expired(self, monkeypatch):
monkeypatch.setattr(SSL.X509, "has_expired", lambda self: True)
monkeypatch.setattr(certs.Cert, "has_expired", lambda self: True)
ta = tlsconfig.TlsConfig()
with taddons.context(ta) as tctx:
ta.configure(["confdir"])

View File

@ -0,0 +1,14 @@
-----BEGIN DH PARAMETERS-----
MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3
O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv
j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ
Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB
chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC
ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq
o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX
IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv
A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8
6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I
rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
-----END DH PARAMETERS-----

View File

@ -0,0 +1,19 @@
-----BEGIN CERTIFICATE-----
MIIDHTCCAtugAwIBAgIJAJNwd38WrLI/MAsGCWCGSAFlAwQDAjBFMQswCQYDVQQG
EwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lk
Z2l0cyBQdHkgTHRkMB4XDTIwMTIzMDE5MzkxMloXDTIxMDEyOTE5MzkxMlowRTEL
MAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVy
bmV0IFdpZGdpdHMgUHR5IEx0ZDCCAbcwggEsBgcqhkjOOAQBMIIBHwKBgQDL8YvJ
wqiJEdWKh7KBbgyCxpWhZdxyLFMguXwI2iX3+n8m2vg/6kBLK/sbI3y3qTeQOvQ0
V/WchWWAnGHT9CwB1MNhqMWd409ZAo49/m6IPYvoB1RctKyUE9E8jUad+jM2WGm9
zCG+KPeYFebTsWiQhhIgx8vPM657x6wA5y+omQIVAP4uMtD8Xv+TZCeXTB4j2bi0
yxjTAoGBAIRxOxIma4B4sEToblQjbwQh9gnWEhvvwAefu3Gcav12mRgYBVqPeskJ
0zROLFr9ubGZo0F/xBGCpaVjY+nDbE8/MwURLBfBr2wQTDbXL/Z5Ea1wNIURsqdo
VxXKADqPh/OfylObCONxF37RnSiABwlCPpIERj+g90oNSFLnvftiA4GEAAKBgC6X
5sBsCV71iCPZSM1UCiSOnzCg/6w5nHPawwcdPUtU6/mpOdylfHlttr4jA1c2a3TB
9ppxe1H0945nYFCHmU5zGtgwKPlNMS5GLrHtSbxiwB5dUm5AsNsLA0EVLF1Po7sI
xQRWTWp57iLzRGdDoKFiPaW07g/UnNzd6i/RAfcWo1MwUTAdBgNVHQ4EFgQUtX4q
D2f00P8e4zp9L+zOib+Fv10wHwYDVR0jBBgwFoAUtX4qD2f00P8e4zp9L+zOib+F
v10wDwYDVR0TAQH/BAUwAwEB/zALBglghkgBZQMEAwIDLwAwLAIURQwecWi0wMjS
EsmC5hxASeuH2UkCFG6zZfM0Sbms27vcrCfVEtxzhizU
-----END CERTIFICATE-----

View File

@ -0,0 +1,12 @@
-----BEGIN CERTIFICATE-----
MIIB1DCCAXqgAwIBAgIJALAxolM7r60uMAoGCCqGSM49BAMCMEUxCzAJBgNVBAYT
AkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn
aXRzIFB0eSBMdGQwHhcNMjAxMjMwMTkzOTIyWhcNMjEwMTI5MTkzOTIyWjBFMQsw
CQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJu
ZXQgV2lkZ2l0cyBQdHkgTHRkMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEbfF4
pQnC+cJvoOyp8g9scQsJLNBsd4Vd0ADCyLBx84B1WucmBfpqi+8ERo8e8Y899UKv
KWov2eqOUmGiCLS3kaNTMFEwHQYDVR0OBBYEFHRF3WQJ6x05/c29Rv0mPZn6CwKb
MB8GA1UdIwQYMBaAFHRF3WQJ6x05/c29Rv0mPZn6CwKbMA8GA1UdEwEB/wQFMAMB
Af8wCgYIKoZIzj0EAwIDSAAwRQIgVDA9XrBWlQQteO0LHlkXVpZH0QXjtnQi426Y
ZxGgG18CIQCVSh3iWEAsCXWWUwiVmqEbpWGNIwwpbK4wL7+5lvUvXw==
-----END CERTIFICATE-----

View File

@ -29,7 +29,7 @@ def test_sslkeylogfile(tdata, monkeypatch):
Path(tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.pem")),
Path(tdata.path("mitmproxy/net/data/dhparam.pem"))
)
cert, key, chain_file = store.get_cert(b"example.com", [], None)
entry = store.get_cert("example.com", [], None)
cctx = tls.create_proxy_server_context(
min_version=tls.DEFAULT_MIN_VERSION,
@ -46,9 +46,9 @@ def test_sslkeylogfile(tdata, monkeypatch):
min_version=tls.DEFAULT_MIN_VERSION,
max_version=tls.DEFAULT_MAX_VERSION,
cipher_list=None,
cert=cert,
key=key,
chain_file=chain_file,
cert=entry.cert,
key=entry.privatekey,
chain_file=entry.chain_file,
alpn_select_callback=None,
request_client_cert=False,
extra_chain_certs=(),
@ -105,7 +105,7 @@ class TestClientHello:
)
c = tls.ClientHello(data)
assert repr(c)
assert c.sni == b'example.com'
assert c.sni == 'example.com'
assert c.cipher_suites == [
49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161,
49171, 49162, 49172, 156, 157, 47, 53, 10

View File

@ -202,7 +202,7 @@ def test_reverse_proxy_tcp_over_tls(tctx: Context, monkeypatch, patch, connectio
>> reply_tls_start()
<< SendData(tctx.server, data)
)
assert tls.parse_client_hello(data()).sni == b"localhost"
assert tls.parse_client_hello(data()).sni == "localhost"
@pytest.mark.parametrize("connection_strategy", ["eager", "lazy"])

View File

@ -76,7 +76,7 @@ def test_get_client_hello():
def test_parse_client_hello():
assert tls.parse_client_hello(client_hello_with_extensions).sni == b"example.com"
assert tls.parse_client_hello(client_hello_with_extensions).sni == "example.com"
assert tls.parse_client_hello(client_hello_with_extensions[:50]) is None
with pytest.raises(ValueError):
tls.parse_client_hello(client_hello_with_extensions[:183] + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00')
@ -188,7 +188,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
tls_start.ssl_conn = SSL.Connection(ssl_context)
tls_start.ssl_conn.set_connect_state()
# Set SNI
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni)
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode())
# Manually enable hostname verification.
# Recent OpenSSL versions provide slightly nicer ways to do this, but they are not exposed in
@ -202,7 +202,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
)
SSL._openssl_assert(
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni, 0) == 1
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode(), 0) == 1
)
return tutils.reply(*args, side_effect=make_conn, **kwargs)
@ -227,7 +227,7 @@ class TestServerTLS:
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.state = ConnectionState.OPEN
tctx.server.address = ("example.mitmproxy.org", 443)
tctx.server.sni = b"example.mitmproxy.org"
tctx.server.sni = "example.mitmproxy.org"
tssl = SSLTest(server_side=True)
@ -280,7 +280,7 @@ class TestServerTLS:
"""If the certificate is not trusted, we should fail."""
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.address = ("wrong.host.mitmproxy.org", 443)
tctx.server.sni = b"wrong.host.mitmproxy.org"
tctx.server.sni = "wrong.host.mitmproxy.org"
tssl = SSLTest(server_side=True)
@ -316,7 +316,7 @@ class TestServerTLS:
def test_remote_speaks_no_tls(self, tctx):
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.state = ConnectionState.OPEN
tctx.server.sni = b"example.mitmproxy.org"
tctx.server.sni = "example.mitmproxy.org"
# send ClientHello, receive random garbage back
data = tutils.Placeholder(bytes)
@ -345,7 +345,7 @@ def make_client_tls_layer(
# Add some server config, this is needed anyways.
tctx.server.address = ("example.mitmproxy.org", 443)
tctx.server.sni = b"example.mitmproxy.org"
tctx.server.sni = "example.mitmproxy.org"
tssl_client = SSLTest(**kwargs)
# Start handshake.

View File

@ -1,7 +1,12 @@
import os
from pathlib import Path
import pytest
from mitmproxy import certs
from ..conftest import skip_windows
# class TestDNTree:
# def test_simple(self):
# d = certs.DNTree()
@ -32,85 +37,87 @@ from ..conftest import skip_windows
# assert d.get("com") == "foo"
@pytest.fixture()
def tstore(tdata):
return certs.CertStore.from_store(tdata.path("mitmproxy/data/confdir"), "mitmproxy", 2048)
class TestCertStore:
def test_create_explicit(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
assert ca.get_cert(b"foo", [])
assert ca.get_cert("foo", [])
ca2 = certs.CertStore.from_store(str(tmpdir), "test", 2048)
assert ca2.get_cert(b"foo", [])
assert ca2.get_cert("foo", [])
assert ca.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
assert ca.default_ca.serial == ca2.default_ca.serial
def test_create_no_common_name(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
assert ca.get_cert(None, [])[0].cn is None
def test_create_no_common_name(self, tstore):
assert tstore.get_cert(None, []).cert.cn is None
def test_create_tmp(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
assert ca.get_cert(b"foo.com", [])
assert ca.get_cert(b"foo.com", [])
assert ca.get_cert(b"*.foo.com", [])
def test_chain_file(self, tdata, tmp_path):
cert = Path(tdata.path("mitmproxy/data/confdir/mitmproxy-ca.pem")).read_bytes()
(tmp_path / "mitmproxy-ca.pem").write_bytes(cert)
ca = certs.CertStore.from_store(tmp_path, "mitmproxy", 2048)
assert ca.default_chain_file is None
r = ca.get_cert(b"*.foo.com", [])
assert r[1] == ca.default_privatekey
(tmp_path / "mitmproxy-ca.pem").write_bytes(2 * cert)
ca = certs.CertStore.from_store(tmp_path, "mitmproxy", 2048)
assert ca.default_chain_file == (tmp_path / "mitmproxy-ca.pem")
def test_sans(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
c1 = ca.get_cert(b"foo.com", [b"*.bar.com"])
ca.get_cert(b"foo.bar.com", [])
def test_sans(self, tstore):
c1 = tstore.get_cert("foo.com", ["*.bar.com"])
tstore.get_cert("foo.bar.com", [])
# assert c1 == c2
c3 = ca.get_cert(b"bar.com", [])
c3 = tstore.get_cert("bar.com", [])
assert not c1 == c3
def test_sans_change(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
ca.get_cert(b"foo.com", [b"*.bar.com"])
cert, key, chain_file = ca.get_cert(b"foo.bar.com", [b"*.baz.com"])
assert b"*.baz.com" in cert.altnames
def test_sans_change(self, tstore):
tstore.get_cert("foo.com", ["*.bar.com"])
entry = tstore.get_cert("foo.bar.com", ["*.baz.com"])
assert "*.baz.com" in entry.cert.altnames
def test_expire(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
ca.STORE_CAP = 3
ca.get_cert(b"one.com", [])
ca.get_cert(b"two.com", [])
ca.get_cert(b"three.com", [])
def test_expire(self, tstore):
tstore.STORE_CAP = 3
tstore.get_cert("one.com", [])
tstore.get_cert("two.com", [])
tstore.get_cert("three.com", [])
assert (b"one.com", ()) in ca.certs
assert (b"two.com", ()) in ca.certs
assert (b"three.com", ()) in ca.certs
assert ("one.com", ()) in tstore.certs
assert ("two.com", ()) in tstore.certs
assert ("three.com", ()) in tstore.certs
ca.get_cert(b"one.com", [])
tstore.get_cert("one.com", [])
assert (b"one.com", ()) in ca.certs
assert (b"two.com", ()) in ca.certs
assert (b"three.com", ()) in ca.certs
assert ("one.com", ()) in tstore.certs
assert ("two.com", ()) in tstore.certs
assert ("three.com", ()) in tstore.certs
ca.get_cert(b"four.com", [])
tstore.get_cert("four.com", [])
assert (b"one.com", ()) not in ca.certs
assert (b"two.com", ()) in ca.certs
assert (b"three.com", ()) in ca.certs
assert (b"four.com", ()) in ca.certs
assert ("one.com", ()) not in tstore.certs
assert ("two.com", ()) in tstore.certs
assert ("three.com", ()) in tstore.certs
assert ("four.com", ()) in tstore.certs
def test_overrides(self, tmpdir):
ca1 = certs.CertStore.from_store(str(tmpdir.join("ca1")), "test", 2048)
ca2 = certs.CertStore.from_store(str(tmpdir.join("ca2")), "test", 2048)
assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
def test_overrides(self, tmp_path):
ca1 = certs.CertStore.from_store(tmp_path / "ca1", "test", 2048)
ca2 = certs.CertStore.from_store(tmp_path / "ca2", "test", 2048)
assert not ca1.default_ca.serial == ca2.default_ca.serial
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dcp = tmpdir.join("dc")
dcp.write(dc[0].to_pem())
ca1.add_cert_file("foo.com", str(dcp))
dc = ca2.get_cert("foo.com", ["sans.example.com"])
dcp = tmp_path / "dc"
dcp.write_bytes(dc.cert.to_pem())
ca1.add_cert_file("foo.com", dcp)
ret = ca1.get_cert(b"foo.com", [])
assert ret[0].serial == dc[0].serial
ret = ca1.get_cert("foo.com", [])
assert ret.cert.serial == dc.cert.serial
def test_create_dhparams(self, tmpdir):
filename = str(tmpdir.join("dhparam.pem"))
def test_create_dhparams(self, tmp_path):
filename = tmp_path / "dhparam.pem"
certs.CertStore.load_dhparam(filename)
assert os.path.exists(filename)
assert filename.exists()
@skip_windows
def test_umask_secret(self, tmpdir):
@ -123,22 +130,21 @@ class TestCertStore:
class TestDummyCert:
def test_with_ca(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
def test_with_ca(self, tstore):
r = certs.dummy_cert(
ca.default_privatekey,
ca.default_ca,
b"foo.com",
[b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"],
b"Foo Ltd."
tstore.default_privatekey,
tstore.default_ca._cert,
"foo.com",
["one.com", "two.com", "*.three.com", "127.0.0.1"],
"Foo Ltd."
)
assert r.cn == b"foo.com"
assert r.altnames == [b'one.com', b'two.com', b'*.three.com']
assert r.organization == b"Foo Ltd."
assert r.cn == "foo.com"
assert r.altnames == ["one.com", "two.com", "*.three.com", "127.0.0.1"]
assert r.organization == "Foo Ltd."
r = certs.dummy_cert(
ca.default_privatekey,
ca.default_ca,
tstore.default_privatekey,
tstore.default_ca._cert,
None,
[],
None
@ -154,16 +160,16 @@ class TestCert:
with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f:
d = f.read()
c1 = certs.Cert.from_pem(d)
assert c1.cn == b"google.com"
assert c1.cn == "google.com"
assert len(c1.altnames) == 436
assert c1.organization == b"Google Inc"
assert c1.organization == "Google Inc"
with open(tdata.path("mitmproxy/net/data/text_cert_2"), "rb") as f:
d = f.read()
c2 = certs.Cert.from_pem(d)
assert c2.cn == b"www.inode.co.nz"
assert c2.cn == "www.inode.co.nz"
assert len(c2.altnames) == 2
assert c2.digest("sha1")
assert c2.fingerprint()
assert c2.notbefore
assert c2.notafter
assert c2.subject
@ -171,10 +177,30 @@ class TestCert:
assert c2.serial
assert c2.issuer
assert c2.to_pem()
assert c2.has_expired is not None
assert c2.has_expired() is not None
assert c1 != c2
def test_convert(self, tdata):
with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f:
d = f.read()
c = certs.Cert.from_pem(d)
assert c == certs.Cert.from_pem(c.to_pem())
assert c == certs.Cert.from_state(c.get_state())
assert c == certs.Cert.from_pyopenssl(c.to_pyopenssl())
@pytest.mark.parametrize("filename,name,bits", [
("text_cert", "RSA", 1024),
("dsa_cert.pem", "DSA", 1024),
("ec_cert.pem", "EC (secp256r1)", 256),
])
def test_keyinfo(self, tdata, filename, name, bits):
with open(tdata.path(f"mitmproxy/net/data/{filename}"), "rb") as f:
d = f.read()
c = certs.Cert.from_pem(d)
assert c.keyinfo == (name, bits)
def test_err_broken_sans(self, tdata):
with open(tdata.path("mitmproxy/net/data/text_cert_weird1"), "rb") as f:
d = f.read()
@ -182,12 +208,6 @@ class TestCert:
# This breaks unless we ignore a decoding error.
assert c.altnames is not None
def test_der(self, tdata):
with open(tdata.path("mitmproxy/net/data/dercert"), "rb") as f:
d = f.read()
s = certs.Cert.from_der(d)
assert s.cn
def test_state(self, tdata):
with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f:
d = f.read()
@ -201,12 +221,13 @@ class TestCert:
assert c == c2
assert c is not c2
x = certs.Cert('')
x.set_state(a)
assert x == c
c2.set_state(a)
assert c == c2
def test_from_store_with_passphrase(self, tdata, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, b"password")
ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), b"password")
def test_from_store_with_passphrase(self, tdata, tstore):
tstore.add_cert_file("unencrypted-no-pass", Path(tdata.path("mitmproxy/data/testkey.pem")), None)
tstore.add_cert_file("unencrypted-pass", Path(tdata.path("mitmproxy/data/testkey.pem")), b"password")
tstore.add_cert_file("encrypted-pass", Path(tdata.path("mitmproxy/data/mitmproxy.pem")), b"password")
assert ca.get_cert(b"foo", [])
with pytest.raises(TypeError):
tstore.add_cert_file("encrypted-no-pass", Path(tdata.path("mitmproxy/data/mitmproxy.pem")), None)

View File

@ -25,12 +25,10 @@ def test_format_keyvals():
("ee", None),
]
)
wrapped = urwid.BoxAdapter(
urwid.ListBox(
wrapped = urwid.Pile(
urwid.SimpleFocusListWalker(
common.format_keyvals([("foo", "bar")])
)
), 1
)
assert wrapped.render((30,))
assert common.format_keyvals(