From 2db9a43fd6af1e9680e302ba456e82222298470d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 28 Dec 2020 09:44:37 +0100 Subject: [PATCH] add type annotations, test sslkeylogfile --- mitmproxy/addons/tlsconfig.py | 54 ++++++++++------- mitmproxy/certs.py | 91 ++++++++++++++-------------- mitmproxy/net/tls.py | 34 ++++++----- test/mitmproxy/net/test_tls.py | 107 +++++++++++++++------------------ test/mitmproxy/test_certs.py | 4 +- 5 files changed, 149 insertions(+), 141 deletions(-) diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index d1c8f5551..c05d91e24 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -1,5 +1,4 @@ import os -from pathlib import Path from typing import List, Optional, Tuple, TypedDict, Any from OpenSSL import SSL, crypto @@ -61,21 +60,34 @@ class TlsConfig: # - ssl_verify_upstream_trusted_confdir def load(self, loader): - for c in ["client", "server"]: - loader.add_option( - name=f"tls_version_{c}_min", - typespec=str, - default=net_tls.DEFAULT_MIN_VERSION.name, - choices=[x.name for x in net_tls.Version], - help=f"Set the minimum TLS version for {c} connections.", - ) - loader.add_option( - name=f"tls_version_{c}_max", - typespec=str, - default=net_tls.DEFAULT_MAX_VERSION.name, - choices=[x.name for x in net_tls.Version], - help=f"Set the maximum TLS version for {c} connections.", - ) + loader.add_option( + name="tls_version_client_min", + typespec=str, + default=net_tls.DEFAULT_MIN_VERSION.name, + choices=[x.name for x in net_tls.Version], + help=f"Set the minimum TLS version for client connections.", + ) + loader.add_option( + name="tls_version_client_max", + typespec=str, + default=net_tls.DEFAULT_MAX_VERSION.name, + choices=[x.name for x in net_tls.Version], + help=f"Set the maximum TLS version for client connections.", + ) + loader.add_option( + name="tls_version_server_min", + typespec=str, + default=net_tls.DEFAULT_MIN_VERSION.name, + choices=[x.name for x in net_tls.Version], + help=f"Set the minimum TLS version for server connections.", + ) + loader.add_option( + name="tls_version_server_max", + typespec=str, + default=net_tls.DEFAULT_MAX_VERSION.name, + choices=[x.name for x in net_tls.Version], + help=f"Set the maximum TLS version for server connections.", + ) def tls_clienthello(self, tls_clienthello: tls.ClientHelloData): conn_context = tls_clienthello.context @@ -163,15 +175,15 @@ class TlsConfig: # don't assign to client.cipher_list, doesn't need to be stored. cipher_list = server.cipher_list or DEFAULT_CIPHERS - client_cert: Optional[Path] = None + client_cert: Optional[str] = None if ctx.options.client_certs: - client_certs = Path(ctx.options.client_certs).expanduser() - if client_certs.is_file(): + client_certs = os.path.expanduser(ctx.options.client_certs) + if os.path.isfile(client_certs): client_cert = client_certs else: server_name: str = (server.sni or server.address[0].encode("idna")).decode() - p = (client_certs / f"{server_name}.pem") - if p.is_file(): + p = os.path.join(client_certs, f"{server_name}.pem") + if os.path.isfile(p): client_cert = p ssl_ctx = net_tls.create_proxy_server_context( diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index 31ae08b28..ece050789 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -4,8 +4,9 @@ import time import datetime import ipaddress import sys -import typing import contextlib +from pathlib import Path +from typing import Tuple, Optional, Union, Dict, List from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.codec.der.decoder import decode @@ -36,7 +37,7 @@ rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI= """ -def create_ca(organization, cn, exp, key_size): +def create_ca(organization: str, cn: str, exp: int, key_size: int) -> Tuple[OpenSSL.crypto.PKey, OpenSSL.crypto.X509]: key = OpenSSL.crypto.PKey() key.generate_key(OpenSSL.crypto.TYPE_RSA, key_size) cert = OpenSSL.crypto.X509() @@ -145,8 +146,8 @@ class CertStoreEntry: TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs) -TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans) -TCertId = typing.Union[TCustomCertId, TGeneratedCertId] +TGeneratedCertId = Tuple[Optional[bytes], Tuple[bytes, ...]] # (common_name, sans) +TCertId = Union[TCustomCertId, TGeneratedCertId] class CertStore: @@ -166,7 +167,7 @@ class CertStore: self.default_ca = default_ca self.default_chain_file = default_chain_file self.dhparams = dhparams - self.certs: typing.Dict[TCertId, CertStoreEntry] = {} + self.certs: Dict[TCertId, CertStoreEntry] = {} self.expire_queue = [] def expire(self, entry): @@ -176,7 +177,7 @@ class CertStore: self.certs = {k: v for k, v in self.certs.items() if v != d} @staticmethod - def load_dhparam(path): + def load_dhparam(path: str): # mitmproxy<=0.10 doesn't generate a dhparam file. # Create it now if necessary. @@ -196,23 +197,28 @@ class CertStore: return dh @classmethod - def from_store(cls, path, basename, key_size, passphrase: typing.Optional[bytes] = None): - ca_path = os.path.join(path, basename + "-ca.pem") - if not os.path.exists(ca_path): - key, ca = cls.create_store(path, basename, key_size) - else: - with open(ca_path, "rb") as f: - raw = f.read() - ca = OpenSSL.crypto.load_certificate( - OpenSSL.crypto.FILETYPE_PEM, - raw) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - raw, - passphrase) - dh_path = os.path.join(path, basename + "-dhparam.pem") - dh = cls.load_dhparam(dh_path) - return cls(key, ca, ca_path, dh) + def from_store(cls, path: Union[Path, str], basename: str, key_size, passphrase: Optional[bytes] = None) -> "CertStore": + path = Path(path) + ca_file = path / f"{basename}-ca.pem" + dhparam_file = path / f"{basename}-dhparam.pem" + if not ca_file.exists(): + cls.create_store(path, basename, key_size) + return cls.from_files(ca_file, dhparam_file, passphrase) + + @classmethod + def from_files(cls, ca_file: Path, dhparam_file: Path, passphrase: Optional[bytes] = None) -> "CertStore": + raw = ca_file.read_bytes() + ca = OpenSSL.crypto.load_certificate( + OpenSSL.crypto.FILETYPE_PEM, + raw + ) + key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, + raw, + passphrase + ) + dh = cls.load_dhparam(str(dhparam_file)) + return cls(key, ca, str(ca_file), dh) @staticmethod @contextlib.contextmanager @@ -230,16 +236,15 @@ class CertStore: os.umask(original_umask) @staticmethod - def create_store(path, basename, key_size, organization=None, cn=None, expiry=DEFAULT_EXP): - if not os.path.exists(path): - os.makedirs(path) + def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None, expiry=DEFAULT_EXP) -> None: + path.mkdir(parents=True, exist_ok=True) organization = organization or basename cn = cn or basename key, ca = create_ca(organization=organization, cn=cn, exp=expiry, key_size=key_size) # Dump the CA plus private key - with CertStore.umask_secret(), open(os.path.join(path, basename + "-ca.pem"), "wb") as f: + with CertStore.umask_secret(), (path / f"{basename}-ca.pem").open("wb") as f: f.write( OpenSSL.crypto.dump_privatekey( OpenSSL.crypto.FILETYPE_PEM, @@ -250,38 +255,36 @@ class CertStore: ca)) # Dump the certificate in PEM format - with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: + with (path / f"{basename}-ca-cert.pem").open("wb") as f: f.write( OpenSSL.crypto.dump_certificate( OpenSSL.crypto.FILETYPE_PEM, ca)) # Create a .cer file with the same contents for Android - with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: + with (path / f"{basename}-ca-cert.cer").open("wb") as f: f.write( OpenSSL.crypto.dump_certificate( OpenSSL.crypto.FILETYPE_PEM, ca)) # Dump the certificate in PKCS12 format for Windows devices - with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: + with (path / f"{basename}-ca-cert.p12").open("wb") as f: p12 = OpenSSL.crypto.PKCS12() p12.set_certificate(ca) f.write(p12.export()) # Dump the certificate and key in a PKCS12 format for Windows devices - with CertStore.umask_secret(), open(os.path.join(path, basename + "-ca.p12"), "wb") as f: + with CertStore.umask_secret(), (path / f"{basename}-ca.p12").open("wb") as f: p12 = OpenSSL.crypto.PKCS12() p12.set_certificate(ca) p12.set_privatekey(key) f.write(p12.export()) - with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: + with (path / f"{basename}-dhparam.pem").open("wb") as f: f.write(DEFAULT_DHPARAM) - return key, ca - - def add_cert_file(self, spec: str, path: str, passphrase: typing.Optional[bytes] = None) -> None: + def add_cert_file(self, spec: str, path: str, passphrase: Optional[bytes] = None) -> None: with open(path, "rb") as f: raw = f.read() cert = Cert( @@ -313,7 +316,7 @@ class CertStore: self.certs[i] = entry @staticmethod - def asterisk_forms(dn: bytes) -> typing.List[bytes]: + def asterisk_forms(dn: bytes) -> List[bytes]: """ Return all asterisk forms for a domain. For example, for www.example.com this will return [b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted. @@ -326,10 +329,10 @@ class CertStore: def get_cert( self, - commonname: typing.Optional[bytes], - sans: typing.List[bytes], - organization: typing.Optional[bytes] = None - ) -> typing.Tuple["Cert", OpenSSL.SSL.PKey, str]: + commonname: Optional[bytes], + sans: List[bytes], + organization: Optional[bytes] = None + ) -> Tuple["Cert", OpenSSL.SSL.PKey, str]: """ Returns an (cert, privkey, cert_chain) tuple. @@ -341,7 +344,7 @@ class CertStore: organization: Organization name for the generated certificate. """ - potential_keys: typing.List[TCertId] = [] + potential_keys: List[TCertId] = [] if commonname: potential_keys.extend(self.asterisk_forms(commonname)) for s in sans: @@ -467,7 +470,7 @@ class Cert(serializable.Serializable): ) @property - def cn(self) -> typing.Optional[bytes]: + def cn(self) -> Optional[bytes]: c = None for i in self.subject: if i[0] == b"CN": @@ -475,7 +478,7 @@ class Cert(serializable.Serializable): return c @property - def organization(self) -> typing.Optional[bytes]: + def organization(self) -> Optional[bytes]: c = None for i in self.subject: if i[0] == b"O": @@ -483,7 +486,7 @@ class Cert(serializable.Serializable): return c @property - def altnames(self) -> typing.List[bytes]: + def altnames(self) -> List[bytes]: """ Returns: All DNS altnames. diff --git a/mitmproxy/net/tls.py b/mitmproxy/net/tls.py index 473a83ef5..a63fe971e 100644 --- a/mitmproxy/net/tls.py +++ b/mitmproxy/net/tls.py @@ -3,7 +3,7 @@ import os import threading from enum import Enum from pathlib import Path -from typing import Iterable, Callable, Optional, Tuple, List, Any +from typing import Iterable, Callable, Optional, Tuple, List, Any, BinaryIO import certifi from kaitaistruct import KaitaiStream @@ -22,6 +22,11 @@ class Method(Enum): # TODO: remove once https://github.com/pyca/pyopenssl/pull/985 has landed. +try: + SSL._lib.TLS_server_method +except AttributeError as e: # pragma: no cover + raise RuntimeError("Your installation of the cryptography Python package is outdated.") from e + SSL.Context._methods.setdefault(Method.TLS_SERVER_METHOD.value, SSL._lib.TLS_server_method) SSL.Context._methods.setdefault(Method.TLS_CLIENT_METHOD.value, SSL._lib.TLS_client_method) @@ -52,7 +57,7 @@ DEFAULT_OPTIONS = ( class MasterSecretLogger: def __init__(self, filename: Path): self.filename = filename.expanduser() - self.f = None + self.f: Optional[BinaryIO] = None self.lock = threading.Lock() # required for functools.wraps, which pyOpenSSL uses. @@ -89,7 +94,7 @@ def _create_ssl_context( method: Method, min_version: Version, max_version: Version, - cipher_list: List[str], + cipher_list: Optional[Iterable[str]], ) -> SSL.Context: context = SSL.Context(method.value) @@ -105,7 +110,7 @@ def _create_ssl_context( context.set_options(DEFAULT_OPTIONS) # Cipher List - if cipher_list: + if cipher_list is not None: try: context.set_cipher_list(b":".join(x.encode() for x in cipher_list)) except SSL.Error as v: @@ -122,13 +127,13 @@ def create_proxy_server_context( *, min_version: Version, max_version: Version, - cipher_list: List[str], + cipher_list: Optional[Iterable[str]], verify: Verify, sni: Optional[bytes], - ca_path: Path, - ca_pemfile: Path, - client_cert: Path, - alpn_protos: Iterable[bytes], + ca_path: Optional[str], + ca_pemfile: Optional[str], + client_cert: Optional[str], + alpn_protos: Optional[Iterable[bytes]], ) -> SSL.Context: context: SSL.Context = _create_ssl_context( method=Method.TLS_CLIENT_METHOD, @@ -165,8 +170,8 @@ def create_proxy_server_context( # Client Certs if client_cert: try: - context.use_privatekey_file(str(client_cert)) - context.use_certificate_chain_file(str(client_cert)) + context.use_privatekey_file(client_cert) + context.use_certificate_chain_file(client_cert) except SSL.Error as v: raise exceptions.TlsException(f"TLS client certificate error: {v}") @@ -181,11 +186,11 @@ def create_client_proxy_context( *, min_version: Version, max_version: Version, - cipher_list: List[str], + cipher_list: Optional[Iterable[str]], cert: certs.Cert, key: SSL.PKey, chain_file: str, - alpn_select_callback: Callable[[SSL.Connection, List[bytes]], Any], + alpn_select_callback: Optional[Callable[[SSL.Connection, List[bytes]], Any]], request_client_cert: bool, extra_chain_certs: Iterable[certs.Cert], dhparams, @@ -249,8 +254,9 @@ def is_tls_record_magic(d): """ d = d[:3] - # TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2 + # TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2, and TLSv1.3 # http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello + # https://tls13.ulfheim.net/ return ( len(d) == 3 and d[0] == 0x16 and diff --git a/test/mitmproxy/net/test_tls.py b/test/mitmproxy/net/test_tls.py index 6b843c811..c12b45128 100644 --- a/test/mitmproxy/net/test_tls.py +++ b/test/mitmproxy/net/test_tls.py @@ -1,3 +1,7 @@ +from pathlib import Path + +from OpenSSL import SSL +from mitmproxy import certs from mitmproxy.net import tls CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex( @@ -17,76 +21,59 @@ def test_make_master_secret_logger(): assert isinstance(tls.make_master_secret_logger("filepath"), tls.MasterSecretLogger) -""" def test_sslkeylogfile(tdata, monkeypatch): keylog = [] monkeypatch.setattr(tls, "log_master_secret", lambda conn, secrets: keylog.append(secrets)) - ctx = tls.create_client_context() + store = certs.CertStore.from_files( + Path(tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.pem")), + Path(tdata.path("mitmproxy/net/data/dhparam.pem")) + ) + cert, key, chain_file = store.get_cert(b"example.com", [], None) - ta = tlsconfig.TlsConfig() - with taddons.context(ta) as tctx: - ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options) - ctx.server.address = ("example.mitmproxy.org", 443) - tctx.configure(ta, ssl_verify_upstream_trusted_ca=tdata.path( - "mitmproxy/net/data/verificationcerts/trusted-root.crt")) - - tls_start = tls.TlsStartData(ctx.server, context=ctx) - ta.tls_start(tls_start) - tssl_client = tls_start.ssl_conn - tssl_server = test_tls.SSLTest(server_side=True) - assert self.do_handshake(tssl_client, tssl_server) -""" - -""" -class TestMasterSecretLogger(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - cipher_list="AES256-SHA" + cctx = tls.create_proxy_server_context( + min_version=tls.DEFAULT_MIN_VERSION, + max_version=tls.DEFAULT_MAX_VERSION, + cipher_list=None, + verify=tls.Verify.VERIFY_NONE, + sni=None, + ca_path=None, + ca_pemfile=None, + client_cert=None, + alpn_protos=(), + ) + sctx = tls.create_client_proxy_context( + min_version=tls.DEFAULT_MIN_VERSION, + max_version=tls.DEFAULT_MAX_VERSION, + cipher_list=None, + cert=cert, + key=key, + chain_file=chain_file, + alpn_select_callback=None, + request_client_cert=False, + extra_chain_certs=(), + dhparams=store.dhparams, ) - def test_log(self, tmpdir): - testval = b"echo!\n" - _logfun = tls.log_master_secret + server = SSL.Connection(sctx) + server.set_accept_state() - logfile = str(tmpdir.join("foo", "bar", "logfile")) - tls.log_master_secret = tls.MasterSecretLogger(logfile) + client = SSL.Connection(cctx) + client.set_connect_state() - c = TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls() - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - c.finish() + read, write = client, server + while True: + try: + print(read) + read.do_handshake() + except SSL.WantReadError: + write.bio_write(read.bio_read(2 ** 16)) + else: + break + read, write = write, read - tls.log_master_secret.close() - with open(logfile, "rb") as f: - assert f.read().count(b"SERVER_HANDSHAKE_TRAFFIC_SECRET") >= 2 - - tls.log_master_secret = _logfun - - def test_create_logfun(self): - assert isinstance( - tls.MasterSecretLogger.create_logfun("test"), - tls.MasterSecretLogger) - assert not tls.MasterSecretLogger.create_logfun(False) - - - -class TestTLSInvalid: - def test_invalid_ssl_method_should_fail(self): - fake_ssl_method = 100500 - with pytest.raises(exceptions.TlsException): - tls.create_proxy_server_context(method=fake_ssl_method) - - def test_alpn_error(self): - with pytest.raises(exceptions.TlsException, match="must be a function"): - tls.create_proxy_server_context(alpn_select_callback="foo") - - with pytest.raises(exceptions.TlsException, match="ALPN error"): - tls.create_proxy_server_context(alpn_select="foo", alpn_select_callback="bar") -""" + assert keylog + assert keylog[0].startswith(b"SERVER_HANDSHAKE_TRAFFIC_SECRET") def test_is_record_magic(): diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index 3a977f3c9..72d891938 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -206,7 +206,7 @@ class TestCert: assert x == c def test_from_store_with_passphrase(self, tdata, tmpdir): - ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, "password") - ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), "password") + ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, b"password") + ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), b"password") assert ca.get_cert(b"foo", [])