add type annotations, test sslkeylogfile

This commit is contained in:
Maximilian Hils 2020-12-28 09:44:37 +01:00
parent de46db53e9
commit 2db9a43fd6
5 changed files with 149 additions and 141 deletions

View File

@ -1,5 +1,4 @@
import os
from pathlib import Path
from typing import List, Optional, Tuple, TypedDict, Any
from OpenSSL import SSL, crypto
@ -61,20 +60,33 @@ class TlsConfig:
# - ssl_verify_upstream_trusted_confdir
def load(self, loader):
for c in ["client", "server"]:
loader.add_option(
name=f"tls_version_{c}_min",
name="tls_version_client_min",
typespec=str,
default=net_tls.DEFAULT_MIN_VERSION.name,
choices=[x.name for x in net_tls.Version],
help=f"Set the minimum TLS version for {c} connections.",
help=f"Set the minimum TLS version for client connections.",
)
loader.add_option(
name=f"tls_version_{c}_max",
name="tls_version_client_max",
typespec=str,
default=net_tls.DEFAULT_MAX_VERSION.name,
choices=[x.name for x in net_tls.Version],
help=f"Set the maximum TLS version for {c} connections.",
help=f"Set the maximum TLS version for client connections.",
)
loader.add_option(
name="tls_version_server_min",
typespec=str,
default=net_tls.DEFAULT_MIN_VERSION.name,
choices=[x.name for x in net_tls.Version],
help=f"Set the minimum TLS version for server connections.",
)
loader.add_option(
name="tls_version_server_max",
typespec=str,
default=net_tls.DEFAULT_MAX_VERSION.name,
choices=[x.name for x in net_tls.Version],
help=f"Set the maximum TLS version for server connections.",
)
def tls_clienthello(self, tls_clienthello: tls.ClientHelloData):
@ -163,15 +175,15 @@ class TlsConfig:
# don't assign to client.cipher_list, doesn't need to be stored.
cipher_list = server.cipher_list or DEFAULT_CIPHERS
client_cert: Optional[Path] = None
client_cert: Optional[str] = None
if ctx.options.client_certs:
client_certs = Path(ctx.options.client_certs).expanduser()
if client_certs.is_file():
client_certs = os.path.expanduser(ctx.options.client_certs)
if os.path.isfile(client_certs):
client_cert = client_certs
else:
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
p = (client_certs / f"{server_name}.pem")
if p.is_file():
p = os.path.join(client_certs, f"{server_name}.pem")
if os.path.isfile(p):
client_cert = p
ssl_ctx = net_tls.create_proxy_server_context(

View File

@ -4,8 +4,9 @@ import time
import datetime
import ipaddress
import sys
import typing
import contextlib
from pathlib import Path
from typing import Tuple, Optional, Union, Dict, List
from pyasn1.type import univ, constraint, char, namedtype, tag
from pyasn1.codec.der.decoder import decode
@ -36,7 +37,7 @@ rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
"""
def create_ca(organization, cn, exp, key_size):
def create_ca(organization: str, cn: str, exp: int, key_size: int) -> Tuple[OpenSSL.crypto.PKey, OpenSSL.crypto.X509]:
key = OpenSSL.crypto.PKey()
key.generate_key(OpenSSL.crypto.TYPE_RSA, key_size)
cert = OpenSSL.crypto.X509()
@ -145,8 +146,8 @@ class CertStoreEntry:
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans)
TCertId = typing.Union[TCustomCertId, TGeneratedCertId]
TGeneratedCertId = Tuple[Optional[bytes], Tuple[bytes, ...]] # (common_name, sans)
TCertId = Union[TCustomCertId, TGeneratedCertId]
class CertStore:
@ -166,7 +167,7 @@ class CertStore:
self.default_ca = default_ca
self.default_chain_file = default_chain_file
self.dhparams = dhparams
self.certs: typing.Dict[TCertId, CertStoreEntry] = {}
self.certs: Dict[TCertId, CertStoreEntry] = {}
self.expire_queue = []
def expire(self, entry):
@ -176,7 +177,7 @@ class CertStore:
self.certs = {k: v for k, v in self.certs.items() if v != d}
@staticmethod
def load_dhparam(path):
def load_dhparam(path: str):
# mitmproxy<=0.10 doesn't generate a dhparam file.
# Create it now if necessary.
@ -196,23 +197,28 @@ class CertStore:
return dh
@classmethod
def from_store(cls, path, basename, key_size, passphrase: typing.Optional[bytes] = None):
ca_path = os.path.join(path, basename + "-ca.pem")
if not os.path.exists(ca_path):
key, ca = cls.create_store(path, basename, key_size)
else:
with open(ca_path, "rb") as f:
raw = f.read()
def from_store(cls, path: Union[Path, str], basename: str, key_size, passphrase: Optional[bytes] = None) -> "CertStore":
path = Path(path)
ca_file = path / f"{basename}-ca.pem"
dhparam_file = path / f"{basename}-dhparam.pem"
if not ca_file.exists():
cls.create_store(path, basename, key_size)
return cls.from_files(ca_file, dhparam_file, passphrase)
@classmethod
def from_files(cls, ca_file: Path, dhparam_file: Path, passphrase: Optional[bytes] = None) -> "CertStore":
raw = ca_file.read_bytes()
ca = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_PEM,
raw)
raw
)
key = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
raw,
passphrase)
dh_path = os.path.join(path, basename + "-dhparam.pem")
dh = cls.load_dhparam(dh_path)
return cls(key, ca, ca_path, dh)
passphrase
)
dh = cls.load_dhparam(str(dhparam_file))
return cls(key, ca, str(ca_file), dh)
@staticmethod
@contextlib.contextmanager
@ -230,16 +236,15 @@ class CertStore:
os.umask(original_umask)
@staticmethod
def create_store(path, basename, key_size, organization=None, cn=None, expiry=DEFAULT_EXP):
if not os.path.exists(path):
os.makedirs(path)
def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None, expiry=DEFAULT_EXP) -> None:
path.mkdir(parents=True, exist_ok=True)
organization = organization or basename
cn = cn or basename
key, ca = create_ca(organization=organization, cn=cn, exp=expiry, key_size=key_size)
# Dump the CA plus private key
with CertStore.umask_secret(), open(os.path.join(path, basename + "-ca.pem"), "wb") as f:
with CertStore.umask_secret(), (path / f"{basename}-ca.pem").open("wb") as f:
f.write(
OpenSSL.crypto.dump_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
@ -250,38 +255,36 @@ class CertStore:
ca))
# Dump the certificate in PEM format
with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f:
with (path / f"{basename}-ca-cert.pem").open("wb") as f:
f.write(
OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM,
ca))
# Create a .cer file with the same contents for Android
with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f:
with (path / f"{basename}-ca-cert.cer").open("wb") as f:
f.write(
OpenSSL.crypto.dump_certificate(
OpenSSL.crypto.FILETYPE_PEM,
ca))
# Dump the certificate in PKCS12 format for Windows devices
with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f:
with (path / f"{basename}-ca-cert.p12").open("wb") as f:
p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca)
f.write(p12.export())
# Dump the certificate and key in a PKCS12 format for Windows devices
with CertStore.umask_secret(), open(os.path.join(path, basename + "-ca.p12"), "wb") as f:
with CertStore.umask_secret(), (path / f"{basename}-ca.p12").open("wb") as f:
p12 = OpenSSL.crypto.PKCS12()
p12.set_certificate(ca)
p12.set_privatekey(key)
f.write(p12.export())
with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f:
with (path / f"{basename}-dhparam.pem").open("wb") as f:
f.write(DEFAULT_DHPARAM)
return key, ca
def add_cert_file(self, spec: str, path: str, passphrase: typing.Optional[bytes] = None) -> None:
def add_cert_file(self, spec: str, path: str, passphrase: Optional[bytes] = None) -> None:
with open(path, "rb") as f:
raw = f.read()
cert = Cert(
@ -313,7 +316,7 @@ class CertStore:
self.certs[i] = entry
@staticmethod
def asterisk_forms(dn: bytes) -> typing.List[bytes]:
def asterisk_forms(dn: bytes) -> List[bytes]:
"""
Return all asterisk forms for a domain. For example, for www.example.com this will return
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
@ -326,10 +329,10 @@ class CertStore:
def get_cert(
self,
commonname: typing.Optional[bytes],
sans: typing.List[bytes],
organization: typing.Optional[bytes] = None
) -> typing.Tuple["Cert", OpenSSL.SSL.PKey, str]:
commonname: Optional[bytes],
sans: List[bytes],
organization: Optional[bytes] = None
) -> Tuple["Cert", OpenSSL.SSL.PKey, str]:
"""
Returns an (cert, privkey, cert_chain) tuple.
@ -341,7 +344,7 @@ class CertStore:
organization: Organization name for the generated certificate.
"""
potential_keys: typing.List[TCertId] = []
potential_keys: List[TCertId] = []
if commonname:
potential_keys.extend(self.asterisk_forms(commonname))
for s in sans:
@ -467,7 +470,7 @@ class Cert(serializable.Serializable):
)
@property
def cn(self) -> typing.Optional[bytes]:
def cn(self) -> Optional[bytes]:
c = None
for i in self.subject:
if i[0] == b"CN":
@ -475,7 +478,7 @@ class Cert(serializable.Serializable):
return c
@property
def organization(self) -> typing.Optional[bytes]:
def organization(self) -> Optional[bytes]:
c = None
for i in self.subject:
if i[0] == b"O":
@ -483,7 +486,7 @@ class Cert(serializable.Serializable):
return c
@property
def altnames(self) -> typing.List[bytes]:
def altnames(self) -> List[bytes]:
"""
Returns:
All DNS altnames.

View File

@ -3,7 +3,7 @@ import os
import threading
from enum import Enum
from pathlib import Path
from typing import Iterable, Callable, Optional, Tuple, List, Any
from typing import Iterable, Callable, Optional, Tuple, List, Any, BinaryIO
import certifi
from kaitaistruct import KaitaiStream
@ -22,6 +22,11 @@ class Method(Enum):
# TODO: remove once https://github.com/pyca/pyopenssl/pull/985 has landed.
try:
SSL._lib.TLS_server_method
except AttributeError as e: # pragma: no cover
raise RuntimeError("Your installation of the cryptography Python package is outdated.") from e
SSL.Context._methods.setdefault(Method.TLS_SERVER_METHOD.value, SSL._lib.TLS_server_method)
SSL.Context._methods.setdefault(Method.TLS_CLIENT_METHOD.value, SSL._lib.TLS_client_method)
@ -52,7 +57,7 @@ DEFAULT_OPTIONS = (
class MasterSecretLogger:
def __init__(self, filename: Path):
self.filename = filename.expanduser()
self.f = None
self.f: Optional[BinaryIO] = None
self.lock = threading.Lock()
# required for functools.wraps, which pyOpenSSL uses.
@ -89,7 +94,7 @@ def _create_ssl_context(
method: Method,
min_version: Version,
max_version: Version,
cipher_list: List[str],
cipher_list: Optional[Iterable[str]],
) -> SSL.Context:
context = SSL.Context(method.value)
@ -105,7 +110,7 @@ def _create_ssl_context(
context.set_options(DEFAULT_OPTIONS)
# Cipher List
if cipher_list:
if cipher_list is not None:
try:
context.set_cipher_list(b":".join(x.encode() for x in cipher_list))
except SSL.Error as v:
@ -122,13 +127,13 @@ def create_proxy_server_context(
*,
min_version: Version,
max_version: Version,
cipher_list: List[str],
cipher_list: Optional[Iterable[str]],
verify: Verify,
sni: Optional[bytes],
ca_path: Path,
ca_pemfile: Path,
client_cert: Path,
alpn_protos: Iterable[bytes],
ca_path: Optional[str],
ca_pemfile: Optional[str],
client_cert: Optional[str],
alpn_protos: Optional[Iterable[bytes]],
) -> SSL.Context:
context: SSL.Context = _create_ssl_context(
method=Method.TLS_CLIENT_METHOD,
@ -165,8 +170,8 @@ def create_proxy_server_context(
# Client Certs
if client_cert:
try:
context.use_privatekey_file(str(client_cert))
context.use_certificate_chain_file(str(client_cert))
context.use_privatekey_file(client_cert)
context.use_certificate_chain_file(client_cert)
except SSL.Error as v:
raise exceptions.TlsException(f"TLS client certificate error: {v}")
@ -181,11 +186,11 @@ def create_client_proxy_context(
*,
min_version: Version,
max_version: Version,
cipher_list: List[str],
cipher_list: Optional[Iterable[str]],
cert: certs.Cert,
key: SSL.PKey,
chain_file: str,
alpn_select_callback: Callable[[SSL.Connection, List[bytes]], Any],
alpn_select_callback: Optional[Callable[[SSL.Connection, List[bytes]], Any]],
request_client_cert: bool,
extra_chain_certs: Iterable[certs.Cert],
dhparams,
@ -249,8 +254,9 @@ def is_tls_record_magic(d):
"""
d = d[:3]
# TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2
# TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2, and TLSv1.3
# http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello
# https://tls13.ulfheim.net/
return (
len(d) == 3 and
d[0] == 0x16 and

View File

@ -1,3 +1,7 @@
from pathlib import Path
from OpenSSL import SSL
from mitmproxy import certs
from mitmproxy.net import tls
CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex(
@ -17,76 +21,59 @@ def test_make_master_secret_logger():
assert isinstance(tls.make_master_secret_logger("filepath"), tls.MasterSecretLogger)
"""
def test_sslkeylogfile(tdata, monkeypatch):
keylog = []
monkeypatch.setattr(tls, "log_master_secret", lambda conn, secrets: keylog.append(secrets))
ctx = tls.create_client_context()
store = certs.CertStore.from_files(
Path(tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.pem")),
Path(tdata.path("mitmproxy/net/data/dhparam.pem"))
)
cert, key, chain_file = store.get_cert(b"example.com", [], None)
ta = tlsconfig.TlsConfig()
with taddons.context(ta) as tctx:
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
ctx.server.address = ("example.mitmproxy.org", 443)
tctx.configure(ta, ssl_verify_upstream_trusted_ca=tdata.path(
"mitmproxy/net/data/verificationcerts/trusted-root.crt"))
tls_start = tls.TlsStartData(ctx.server, context=ctx)
ta.tls_start(tls_start)
tssl_client = tls_start.ssl_conn
tssl_server = test_tls.SSLTest(server_side=True)
assert self.do_handshake(tssl_client, tssl_server)
"""
"""
class TestMasterSecretLogger(tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
cipher_list="AES256-SHA"
cctx = tls.create_proxy_server_context(
min_version=tls.DEFAULT_MIN_VERSION,
max_version=tls.DEFAULT_MAX_VERSION,
cipher_list=None,
verify=tls.Verify.VERIFY_NONE,
sni=None,
ca_path=None,
ca_pemfile=None,
client_cert=None,
alpn_protos=(),
)
sctx = tls.create_client_proxy_context(
min_version=tls.DEFAULT_MIN_VERSION,
max_version=tls.DEFAULT_MAX_VERSION,
cipher_list=None,
cert=cert,
key=key,
chain_file=chain_file,
alpn_select_callback=None,
request_client_cert=False,
extra_chain_certs=(),
dhparams=store.dhparams,
)
def test_log(self, tmpdir):
testval = b"echo!\n"
_logfun = tls.log_master_secret
server = SSL.Connection(sctx)
server.set_accept_state()
logfile = str(tmpdir.join("foo", "bar", "logfile"))
tls.log_master_secret = tls.MasterSecretLogger(logfile)
client = SSL.Connection(cctx)
client.set_connect_state()
c = TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls()
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
c.finish()
read, write = client, server
while True:
try:
print(read)
read.do_handshake()
except SSL.WantReadError:
write.bio_write(read.bio_read(2 ** 16))
else:
break
read, write = write, read
tls.log_master_secret.close()
with open(logfile, "rb") as f:
assert f.read().count(b"SERVER_HANDSHAKE_TRAFFIC_SECRET") >= 2
tls.log_master_secret = _logfun
def test_create_logfun(self):
assert isinstance(
tls.MasterSecretLogger.create_logfun("test"),
tls.MasterSecretLogger)
assert not tls.MasterSecretLogger.create_logfun(False)
class TestTLSInvalid:
def test_invalid_ssl_method_should_fail(self):
fake_ssl_method = 100500
with pytest.raises(exceptions.TlsException):
tls.create_proxy_server_context(method=fake_ssl_method)
def test_alpn_error(self):
with pytest.raises(exceptions.TlsException, match="must be a function"):
tls.create_proxy_server_context(alpn_select_callback="foo")
with pytest.raises(exceptions.TlsException, match="ALPN error"):
tls.create_proxy_server_context(alpn_select="foo", alpn_select_callback="bar")
"""
assert keylog
assert keylog[0].startswith(b"SERVER_HANDSHAKE_TRAFFIC_SECRET")
def test_is_record_magic():

View File

@ -206,7 +206,7 @@ class TestCert:
assert x == c
def test_from_store_with_passphrase(self, tdata, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, "password")
ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), "password")
ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, b"password")
ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), b"password")
assert ca.get_cert(b"foo", [])