add type annotations, test sslkeylogfile

This commit is contained in:
Maximilian Hils 2020-12-28 09:44:37 +01:00
parent de46db53e9
commit 2db9a43fd6
5 changed files with 149 additions and 141 deletions

View File

@ -1,5 +1,4 @@
import os import os
from pathlib import Path
from typing import List, Optional, Tuple, TypedDict, Any from typing import List, Optional, Tuple, TypedDict, Any
from OpenSSL import SSL, crypto from OpenSSL import SSL, crypto
@ -61,20 +60,33 @@ class TlsConfig:
# - ssl_verify_upstream_trusted_confdir # - ssl_verify_upstream_trusted_confdir
def load(self, loader): def load(self, loader):
for c in ["client", "server"]:
loader.add_option( loader.add_option(
name=f"tls_version_{c}_min", name="tls_version_client_min",
typespec=str, typespec=str,
default=net_tls.DEFAULT_MIN_VERSION.name, default=net_tls.DEFAULT_MIN_VERSION.name,
choices=[x.name for x in net_tls.Version], choices=[x.name for x in net_tls.Version],
help=f"Set the minimum TLS version for {c} connections.", help=f"Set the minimum TLS version for client connections.",
) )
loader.add_option( loader.add_option(
name=f"tls_version_{c}_max", name="tls_version_client_max",
typespec=str, typespec=str,
default=net_tls.DEFAULT_MAX_VERSION.name, default=net_tls.DEFAULT_MAX_VERSION.name,
choices=[x.name for x in net_tls.Version], choices=[x.name for x in net_tls.Version],
help=f"Set the maximum TLS version for {c} connections.", help=f"Set the maximum TLS version for client connections.",
)
loader.add_option(
name="tls_version_server_min",
typespec=str,
default=net_tls.DEFAULT_MIN_VERSION.name,
choices=[x.name for x in net_tls.Version],
help=f"Set the minimum TLS version for server connections.",
)
loader.add_option(
name="tls_version_server_max",
typespec=str,
default=net_tls.DEFAULT_MAX_VERSION.name,
choices=[x.name for x in net_tls.Version],
help=f"Set the maximum TLS version for server connections.",
) )
def tls_clienthello(self, tls_clienthello: tls.ClientHelloData): def tls_clienthello(self, tls_clienthello: tls.ClientHelloData):
@ -163,15 +175,15 @@ class TlsConfig:
# don't assign to client.cipher_list, doesn't need to be stored. # don't assign to client.cipher_list, doesn't need to be stored.
cipher_list = server.cipher_list or DEFAULT_CIPHERS cipher_list = server.cipher_list or DEFAULT_CIPHERS
client_cert: Optional[Path] = None client_cert: Optional[str] = None
if ctx.options.client_certs: if ctx.options.client_certs:
client_certs = Path(ctx.options.client_certs).expanduser() client_certs = os.path.expanduser(ctx.options.client_certs)
if client_certs.is_file(): if os.path.isfile(client_certs):
client_cert = client_certs client_cert = client_certs
else: else:
server_name: str = (server.sni or server.address[0].encode("idna")).decode() server_name: str = (server.sni or server.address[0].encode("idna")).decode()
p = (client_certs / f"{server_name}.pem") p = os.path.join(client_certs, f"{server_name}.pem")
if p.is_file(): if os.path.isfile(p):
client_cert = p client_cert = p
ssl_ctx = net_tls.create_proxy_server_context( ssl_ctx = net_tls.create_proxy_server_context(

View File

@ -4,8 +4,9 @@ import time
import datetime import datetime
import ipaddress import ipaddress
import sys import sys
import typing
import contextlib import contextlib
from pathlib import Path
from typing import Tuple, Optional, Union, Dict, List
from pyasn1.type import univ, constraint, char, namedtype, tag from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode from pyasn1.codec.der.decoder import decode
@ -36,7 +37,7 @@ rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
""" """
def create_ca(organization, cn, exp, key_size): def create_ca(organization: str, cn: str, exp: int, key_size: int) -> Tuple[OpenSSL.crypto.PKey, OpenSSL.crypto.X509]:
key = OpenSSL.crypto.PKey() key = OpenSSL.crypto.PKey()
key.generate_key(OpenSSL.crypto.TYPE_RSA, key_size) key.generate_key(OpenSSL.crypto.TYPE_RSA, key_size)
cert = OpenSSL.crypto.X509() cert = OpenSSL.crypto.X509()
@ -145,8 +146,8 @@ class CertStoreEntry:
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs) TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans) TGeneratedCertId = Tuple[Optional[bytes], Tuple[bytes, ...]] # (common_name, sans)
TCertId = typing.Union[TCustomCertId, TGeneratedCertId] TCertId = Union[TCustomCertId, TGeneratedCertId]
class CertStore: class CertStore:
@ -166,7 +167,7 @@ class CertStore:
self.default_ca = default_ca self.default_ca = default_ca
self.default_chain_file = default_chain_file self.default_chain_file = default_chain_file
self.dhparams = dhparams self.dhparams = dhparams
self.certs: typing.Dict[TCertId, CertStoreEntry] = {} self.certs: Dict[TCertId, CertStoreEntry] = {}
self.expire_queue = [] self.expire_queue = []
def expire(self, entry): def expire(self, entry):
@ -176,7 +177,7 @@ class CertStore:
self.certs = {k: v for k, v in self.certs.items() if v != d} self.certs = {k: v for k, v in self.certs.items() if v != d}
@staticmethod @staticmethod
def load_dhparam(path): def load_dhparam(path: str):
# mitmproxy<=0.10 doesn't generate a dhparam file. # mitmproxy<=0.10 doesn't generate a dhparam file.
# Create it now if necessary. # Create it now if necessary.
@ -196,23 +197,28 @@ class CertStore:
return dh return dh
@classmethod @classmethod
def from_store(cls, path, basename, key_size, passphrase: typing.Optional[bytes] = None): def from_store(cls, path: Union[Path, str], basename: str, key_size, passphrase: Optional[bytes] = None) -> "CertStore":
ca_path = os.path.join(path, basename + "-ca.pem") path = Path(path)
if not os.path.exists(ca_path): ca_file = path / f"{basename}-ca.pem"
key, ca = cls.create_store(path, basename, key_size) dhparam_file = path / f"{basename}-dhparam.pem"
else: if not ca_file.exists():
with open(ca_path, "rb") as f: cls.create_store(path, basename, key_size)
raw = f.read() return cls.from_files(ca_file, dhparam_file, passphrase)
@classmethod
def from_files(cls, ca_file: Path, dhparam_file: Path, passphrase: Optional[bytes] = None) -> "CertStore":
raw = ca_file.read_bytes()
ca = OpenSSL.crypto.load_certificate( ca = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_PEM,
raw) raw
)
key = OpenSSL.crypto.load_privatekey( key = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_PEM,
raw, raw,
passphrase) passphrase
dh_path = os.path.join(path, basename + "-dhparam.pem") )
dh = cls.load_dhparam(dh_path) dh = cls.load_dhparam(str(dhparam_file))
return cls(key, ca, ca_path, dh) return cls(key, ca, str(ca_file), dh)
@staticmethod @staticmethod
@contextlib.contextmanager @contextlib.contextmanager
@ -230,16 +236,15 @@ class CertStore:
os.umask(original_umask) os.umask(original_umask)
@staticmethod @staticmethod
def create_store(path, basename, key_size, organization=None, cn=None, expiry=DEFAULT_EXP): def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None, expiry=DEFAULT_EXP) -> None:
if not os.path.exists(path): path.mkdir(parents=True, exist_ok=True)
os.makedirs(path)
organization = organization or basename organization = organization or basename
cn = cn or basename cn = cn or basename
key, ca = create_ca(organization=organization, cn=cn, exp=expiry, key_size=key_size) key, ca = create_ca(organization=organization, cn=cn, exp=expiry, key_size=key_size)
# Dump the CA plus private key # Dump the CA plus private key
with CertStore.umask_secret(), open(os.path.join(path, basename + "-ca.pem"), "wb") as f: with CertStore.umask_secret(), (path / f"{basename}-ca.pem").open("wb") as f:
f.write( f.write(
OpenSSL.crypto.dump_privatekey( OpenSSL.crypto.dump_privatekey(
OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_PEM,
@ -250,38 +255,36 @@ class CertStore:
ca)) ca))
# Dump the certificate in PEM format # Dump the certificate in PEM format
with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f: with (path / f"{basename}-ca-cert.pem").open("wb") as f:
f.write( f.write(
OpenSSL.crypto.dump_certificate( OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_PEM,
ca)) ca))
# Create a .cer file with the same contents for Android # Create a .cer file with the same contents for Android
with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f: with (path / f"{basename}-ca-cert.cer").open("wb") as f:
f.write( f.write(
OpenSSL.crypto.dump_certificate( OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM, OpenSSL.crypto.FILETYPE_PEM,
ca)) ca))
# Dump the certificate in PKCS12 format for Windows devices # Dump the certificate in PKCS12 format for Windows devices
with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f: with (path / f"{basename}-ca-cert.p12").open("wb") as f:
p12 = OpenSSL.crypto.PKCS12() p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca) p12.set_certificate(ca)
f.write(p12.export()) f.write(p12.export())
# Dump the certificate and key in a PKCS12 format for Windows devices # Dump the certificate and key in a PKCS12 format for Windows devices
with CertStore.umask_secret(), open(os.path.join(path, basename + "-ca.p12"), "wb") as f: with CertStore.umask_secret(), (path / f"{basename}-ca.p12").open("wb") as f:
p12 = OpenSSL.crypto.PKCS12() p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca) p12.set_certificate(ca)
p12.set_privatekey(key) p12.set_privatekey(key)
f.write(p12.export()) f.write(p12.export())
with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f: with (path / f"{basename}-dhparam.pem").open("wb") as f:
f.write(DEFAULT_DHPARAM) f.write(DEFAULT_DHPARAM)
return key, ca def add_cert_file(self, spec: str, path: str, passphrase: Optional[bytes] = None) -> None:
def add_cert_file(self, spec: str, path: str, passphrase: typing.Optional[bytes] = None) -> None:
with open(path, "rb") as f: with open(path, "rb") as f:
raw = f.read() raw = f.read()
cert = Cert( cert = Cert(
@ -313,7 +316,7 @@ class CertStore:
self.certs[i] = entry self.certs[i] = entry
@staticmethod @staticmethod
def asterisk_forms(dn: bytes) -> typing.List[bytes]: def asterisk_forms(dn: bytes) -> List[bytes]:
""" """
Return all asterisk forms for a domain. For example, for www.example.com this will return Return all asterisk forms for a domain. For example, for www.example.com this will return
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted. [b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
@ -326,10 +329,10 @@ class CertStore:
def get_cert( def get_cert(
self, self,
commonname: typing.Optional[bytes], commonname: Optional[bytes],
sans: typing.List[bytes], sans: List[bytes],
organization: typing.Optional[bytes] = None organization: Optional[bytes] = None
) -> typing.Tuple["Cert", OpenSSL.SSL.PKey, str]: ) -> Tuple["Cert", OpenSSL.SSL.PKey, str]:
""" """
Returns an (cert, privkey, cert_chain) tuple. Returns an (cert, privkey, cert_chain) tuple.
@ -341,7 +344,7 @@ class CertStore:
organization: Organization name for the generated certificate. organization: Organization name for the generated certificate.
""" """
potential_keys: typing.List[TCertId] = [] potential_keys: List[TCertId] = []
if commonname: if commonname:
potential_keys.extend(self.asterisk_forms(commonname)) potential_keys.extend(self.asterisk_forms(commonname))
for s in sans: for s in sans:
@ -467,7 +470,7 @@ class Cert(serializable.Serializable):
) )
@property @property
def cn(self) -> typing.Optional[bytes]: def cn(self) -> Optional[bytes]:
c = None c = None
for i in self.subject: for i in self.subject:
if i[0] == b"CN": if i[0] == b"CN":
@ -475,7 +478,7 @@ class Cert(serializable.Serializable):
return c return c
@property @property
def organization(self) -> typing.Optional[bytes]: def organization(self) -> Optional[bytes]:
c = None c = None
for i in self.subject: for i in self.subject:
if i[0] == b"O": if i[0] == b"O":
@ -483,7 +486,7 @@ class Cert(serializable.Serializable):
return c return c
@property @property
def altnames(self) -> typing.List[bytes]: def altnames(self) -> List[bytes]:
""" """
Returns: Returns:
All DNS altnames. All DNS altnames.

View File

@ -3,7 +3,7 @@ import os
import threading import threading
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Iterable, Callable, Optional, Tuple, List, Any from typing import Iterable, Callable, Optional, Tuple, List, Any, BinaryIO
import certifi import certifi
from kaitaistruct import KaitaiStream from kaitaistruct import KaitaiStream
@ -22,6 +22,11 @@ class Method(Enum):
# TODO: remove once https://github.com/pyca/pyopenssl/pull/985 has landed. # TODO: remove once https://github.com/pyca/pyopenssl/pull/985 has landed.
try:
SSL._lib.TLS_server_method
except AttributeError as e: # pragma: no cover
raise RuntimeError("Your installation of the cryptography Python package is outdated.") from e
SSL.Context._methods.setdefault(Method.TLS_SERVER_METHOD.value, SSL._lib.TLS_server_method) SSL.Context._methods.setdefault(Method.TLS_SERVER_METHOD.value, SSL._lib.TLS_server_method)
SSL.Context._methods.setdefault(Method.TLS_CLIENT_METHOD.value, SSL._lib.TLS_client_method) SSL.Context._methods.setdefault(Method.TLS_CLIENT_METHOD.value, SSL._lib.TLS_client_method)
@ -52,7 +57,7 @@ DEFAULT_OPTIONS = (
class MasterSecretLogger: class MasterSecretLogger:
def __init__(self, filename: Path): def __init__(self, filename: Path):
self.filename = filename.expanduser() self.filename = filename.expanduser()
self.f = None self.f: Optional[BinaryIO] = None
self.lock = threading.Lock() self.lock = threading.Lock()
# required for functools.wraps, which pyOpenSSL uses. # required for functools.wraps, which pyOpenSSL uses.
@ -89,7 +94,7 @@ def _create_ssl_context(
method: Method, method: Method,
min_version: Version, min_version: Version,
max_version: Version, max_version: Version,
cipher_list: List[str], cipher_list: Optional[Iterable[str]],
) -> SSL.Context: ) -> SSL.Context:
context = SSL.Context(method.value) context = SSL.Context(method.value)
@ -105,7 +110,7 @@ def _create_ssl_context(
context.set_options(DEFAULT_OPTIONS) context.set_options(DEFAULT_OPTIONS)
# Cipher List # Cipher List
if cipher_list: if cipher_list is not None:
try: try:
context.set_cipher_list(b":".join(x.encode() for x in cipher_list)) context.set_cipher_list(b":".join(x.encode() for x in cipher_list))
except SSL.Error as v: except SSL.Error as v:
@ -122,13 +127,13 @@ def create_proxy_server_context(
*, *,
min_version: Version, min_version: Version,
max_version: Version, max_version: Version,
cipher_list: List[str], cipher_list: Optional[Iterable[str]],
verify: Verify, verify: Verify,
sni: Optional[bytes], sni: Optional[bytes],
ca_path: Path, ca_path: Optional[str],
ca_pemfile: Path, ca_pemfile: Optional[str],
client_cert: Path, client_cert: Optional[str],
alpn_protos: Iterable[bytes], alpn_protos: Optional[Iterable[bytes]],
) -> SSL.Context: ) -> SSL.Context:
context: SSL.Context = _create_ssl_context( context: SSL.Context = _create_ssl_context(
method=Method.TLS_CLIENT_METHOD, method=Method.TLS_CLIENT_METHOD,
@ -165,8 +170,8 @@ def create_proxy_server_context(
# Client Certs # Client Certs
if client_cert: if client_cert:
try: try:
context.use_privatekey_file(str(client_cert)) context.use_privatekey_file(client_cert)
context.use_certificate_chain_file(str(client_cert)) context.use_certificate_chain_file(client_cert)
except SSL.Error as v: except SSL.Error as v:
raise exceptions.TlsException(f"TLS client certificate error: {v}") raise exceptions.TlsException(f"TLS client certificate error: {v}")
@ -181,11 +186,11 @@ def create_client_proxy_context(
*, *,
min_version: Version, min_version: Version,
max_version: Version, max_version: Version,
cipher_list: List[str], cipher_list: Optional[Iterable[str]],
cert: certs.Cert, cert: certs.Cert,
key: SSL.PKey, key: SSL.PKey,
chain_file: str, chain_file: str,
alpn_select_callback: Callable[[SSL.Connection, List[bytes]], Any], alpn_select_callback: Optional[Callable[[SSL.Connection, List[bytes]], Any]],
request_client_cert: bool, request_client_cert: bool,
extra_chain_certs: Iterable[certs.Cert], extra_chain_certs: Iterable[certs.Cert],
dhparams, dhparams,
@ -249,8 +254,9 @@ def is_tls_record_magic(d):
""" """
d = d[:3] d = d[:3]
# TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2 # TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2, and TLSv1.3
# http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello # http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello
# https://tls13.ulfheim.net/
return ( return (
len(d) == 3 and len(d) == 3 and
d[0] == 0x16 and d[0] == 0x16 and

View File

@ -1,3 +1,7 @@
from pathlib import Path
from OpenSSL import SSL
from mitmproxy import certs
from mitmproxy.net import tls from mitmproxy.net import tls
CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex( CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex(
@ -17,76 +21,59 @@ def test_make_master_secret_logger():
assert isinstance(tls.make_master_secret_logger("filepath"), tls.MasterSecretLogger) assert isinstance(tls.make_master_secret_logger("filepath"), tls.MasterSecretLogger)
"""
def test_sslkeylogfile(tdata, monkeypatch): def test_sslkeylogfile(tdata, monkeypatch):
keylog = [] keylog = []
monkeypatch.setattr(tls, "log_master_secret", lambda conn, secrets: keylog.append(secrets)) monkeypatch.setattr(tls, "log_master_secret", lambda conn, secrets: keylog.append(secrets))
ctx = tls.create_client_context() store = certs.CertStore.from_files(
Path(tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.pem")),
Path(tdata.path("mitmproxy/net/data/dhparam.pem"))
)
cert, key, chain_file = store.get_cert(b"example.com", [], None)
ta = tlsconfig.TlsConfig() cctx = tls.create_proxy_server_context(
with taddons.context(ta) as tctx: min_version=tls.DEFAULT_MIN_VERSION,
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options) max_version=tls.DEFAULT_MAX_VERSION,
ctx.server.address = ("example.mitmproxy.org", 443) cipher_list=None,
tctx.configure(ta, ssl_verify_upstream_trusted_ca=tdata.path( verify=tls.Verify.VERIFY_NONE,
"mitmproxy/net/data/verificationcerts/trusted-root.crt")) sni=None,
ca_path=None,
tls_start = tls.TlsStartData(ctx.server, context=ctx) ca_pemfile=None,
ta.tls_start(tls_start) client_cert=None,
tssl_client = tls_start.ssl_conn alpn_protos=(),
tssl_server = test_tls.SSLTest(server_side=True) )
assert self.do_handshake(tssl_client, tssl_server) sctx = tls.create_client_proxy_context(
""" min_version=tls.DEFAULT_MIN_VERSION,
max_version=tls.DEFAULT_MAX_VERSION,
""" cipher_list=None,
class TestMasterSecretLogger(tservers.ServerTestBase): cert=cert,
handler = EchoHandler key=key,
ssl = dict( chain_file=chain_file,
cipher_list="AES256-SHA" alpn_select_callback=None,
request_client_cert=False,
extra_chain_certs=(),
dhparams=store.dhparams,
) )
def test_log(self, tmpdir): server = SSL.Connection(sctx)
testval = b"echo!\n" server.set_accept_state()
_logfun = tls.log_master_secret
logfile = str(tmpdir.join("foo", "bar", "logfile")) client = SSL.Connection(cctx)
tls.log_master_secret = tls.MasterSecretLogger(logfile) client.set_connect_state()
c = TCPClient(("127.0.0.1", self.port)) read, write = client, server
with c.connect(): while True:
c.convert_to_tls() try:
c.wfile.write(testval) print(read)
c.wfile.flush() read.do_handshake()
assert c.rfile.readline() == testval except SSL.WantReadError:
c.finish() write.bio_write(read.bio_read(2 ** 16))
else:
break
read, write = write, read
tls.log_master_secret.close() assert keylog
with open(logfile, "rb") as f: assert keylog[0].startswith(b"SERVER_HANDSHAKE_TRAFFIC_SECRET")
assert f.read().count(b"SERVER_HANDSHAKE_TRAFFIC_SECRET") >= 2
tls.log_master_secret = _logfun
def test_create_logfun(self):
assert isinstance(
tls.MasterSecretLogger.create_logfun("test"),
tls.MasterSecretLogger)
assert not tls.MasterSecretLogger.create_logfun(False)
class TestTLSInvalid:
def test_invalid_ssl_method_should_fail(self):
fake_ssl_method = 100500
with pytest.raises(exceptions.TlsException):
tls.create_proxy_server_context(method=fake_ssl_method)
def test_alpn_error(self):
with pytest.raises(exceptions.TlsException, match="must be a function"):
tls.create_proxy_server_context(alpn_select_callback="foo")
with pytest.raises(exceptions.TlsException, match="ALPN error"):
tls.create_proxy_server_context(alpn_select="foo", alpn_select_callback="bar")
"""
def test_is_record_magic(): def test_is_record_magic():

View File

@ -206,7 +206,7 @@ class TestCert:
assert x == c assert x == c
def test_from_store_with_passphrase(self, tdata, tmpdir): def test_from_store_with_passphrase(self, tdata, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, "password") ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, b"password")
ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), "password") ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), b"password")
assert ca.get_cert(b"foo", []) assert ca.get_cert(b"foo", [])