mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-29 11:03:13 +00:00
add type annotations, test sslkeylogfile
This commit is contained in:
parent
de46db53e9
commit
2db9a43fd6
@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Optional, Tuple, TypedDict, Any
|
from typing import List, Optional, Tuple, TypedDict, Any
|
||||||
|
|
||||||
from OpenSSL import SSL, crypto
|
from OpenSSL import SSL, crypto
|
||||||
@ -61,20 +60,33 @@ class TlsConfig:
|
|||||||
# - ssl_verify_upstream_trusted_confdir
|
# - ssl_verify_upstream_trusted_confdir
|
||||||
|
|
||||||
def load(self, loader):
|
def load(self, loader):
|
||||||
for c in ["client", "server"]:
|
|
||||||
loader.add_option(
|
loader.add_option(
|
||||||
name=f"tls_version_{c}_min",
|
name="tls_version_client_min",
|
||||||
typespec=str,
|
typespec=str,
|
||||||
default=net_tls.DEFAULT_MIN_VERSION.name,
|
default=net_tls.DEFAULT_MIN_VERSION.name,
|
||||||
choices=[x.name for x in net_tls.Version],
|
choices=[x.name for x in net_tls.Version],
|
||||||
help=f"Set the minimum TLS version for {c} connections.",
|
help=f"Set the minimum TLS version for client connections.",
|
||||||
)
|
)
|
||||||
loader.add_option(
|
loader.add_option(
|
||||||
name=f"tls_version_{c}_max",
|
name="tls_version_client_max",
|
||||||
typespec=str,
|
typespec=str,
|
||||||
default=net_tls.DEFAULT_MAX_VERSION.name,
|
default=net_tls.DEFAULT_MAX_VERSION.name,
|
||||||
choices=[x.name for x in net_tls.Version],
|
choices=[x.name for x in net_tls.Version],
|
||||||
help=f"Set the maximum TLS version for {c} connections.",
|
help=f"Set the maximum TLS version for client connections.",
|
||||||
|
)
|
||||||
|
loader.add_option(
|
||||||
|
name="tls_version_server_min",
|
||||||
|
typespec=str,
|
||||||
|
default=net_tls.DEFAULT_MIN_VERSION.name,
|
||||||
|
choices=[x.name for x in net_tls.Version],
|
||||||
|
help=f"Set the minimum TLS version for server connections.",
|
||||||
|
)
|
||||||
|
loader.add_option(
|
||||||
|
name="tls_version_server_max",
|
||||||
|
typespec=str,
|
||||||
|
default=net_tls.DEFAULT_MAX_VERSION.name,
|
||||||
|
choices=[x.name for x in net_tls.Version],
|
||||||
|
help=f"Set the maximum TLS version for server connections.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def tls_clienthello(self, tls_clienthello: tls.ClientHelloData):
|
def tls_clienthello(self, tls_clienthello: tls.ClientHelloData):
|
||||||
@ -163,15 +175,15 @@ class TlsConfig:
|
|||||||
# don't assign to client.cipher_list, doesn't need to be stored.
|
# don't assign to client.cipher_list, doesn't need to be stored.
|
||||||
cipher_list = server.cipher_list or DEFAULT_CIPHERS
|
cipher_list = server.cipher_list or DEFAULT_CIPHERS
|
||||||
|
|
||||||
client_cert: Optional[Path] = None
|
client_cert: Optional[str] = None
|
||||||
if ctx.options.client_certs:
|
if ctx.options.client_certs:
|
||||||
client_certs = Path(ctx.options.client_certs).expanduser()
|
client_certs = os.path.expanduser(ctx.options.client_certs)
|
||||||
if client_certs.is_file():
|
if os.path.isfile(client_certs):
|
||||||
client_cert = client_certs
|
client_cert = client_certs
|
||||||
else:
|
else:
|
||||||
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
|
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
|
||||||
p = (client_certs / f"{server_name}.pem")
|
p = os.path.join(client_certs, f"{server_name}.pem")
|
||||||
if p.is_file():
|
if os.path.isfile(p):
|
||||||
client_cert = p
|
client_cert = p
|
||||||
|
|
||||||
ssl_ctx = net_tls.create_proxy_server_context(
|
ssl_ctx = net_tls.create_proxy_server_context(
|
||||||
|
@ -4,8 +4,9 @@ import time
|
|||||||
import datetime
|
import datetime
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import sys
|
import sys
|
||||||
import typing
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple, Optional, Union, Dict, List
|
||||||
|
|
||||||
from pyasn1.type import univ, constraint, char, namedtype, tag
|
from pyasn1.type import univ, constraint, char, namedtype, tag
|
||||||
from pyasn1.codec.der.decoder import decode
|
from pyasn1.codec.der.decoder import decode
|
||||||
@ -36,7 +37,7 @@ rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def create_ca(organization, cn, exp, key_size):
|
def create_ca(organization: str, cn: str, exp: int, key_size: int) -> Tuple[OpenSSL.crypto.PKey, OpenSSL.crypto.X509]:
|
||||||
key = OpenSSL.crypto.PKey()
|
key = OpenSSL.crypto.PKey()
|
||||||
key.generate_key(OpenSSL.crypto.TYPE_RSA, key_size)
|
key.generate_key(OpenSSL.crypto.TYPE_RSA, key_size)
|
||||||
cert = OpenSSL.crypto.X509()
|
cert = OpenSSL.crypto.X509()
|
||||||
@ -145,8 +146,8 @@ class CertStoreEntry:
|
|||||||
|
|
||||||
|
|
||||||
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
|
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
|
||||||
TGeneratedCertId = typing.Tuple[typing.Optional[bytes], typing.Tuple[bytes, ...]] # (common_name, sans)
|
TGeneratedCertId = Tuple[Optional[bytes], Tuple[bytes, ...]] # (common_name, sans)
|
||||||
TCertId = typing.Union[TCustomCertId, TGeneratedCertId]
|
TCertId = Union[TCustomCertId, TGeneratedCertId]
|
||||||
|
|
||||||
|
|
||||||
class CertStore:
|
class CertStore:
|
||||||
@ -166,7 +167,7 @@ class CertStore:
|
|||||||
self.default_ca = default_ca
|
self.default_ca = default_ca
|
||||||
self.default_chain_file = default_chain_file
|
self.default_chain_file = default_chain_file
|
||||||
self.dhparams = dhparams
|
self.dhparams = dhparams
|
||||||
self.certs: typing.Dict[TCertId, CertStoreEntry] = {}
|
self.certs: Dict[TCertId, CertStoreEntry] = {}
|
||||||
self.expire_queue = []
|
self.expire_queue = []
|
||||||
|
|
||||||
def expire(self, entry):
|
def expire(self, entry):
|
||||||
@ -176,7 +177,7 @@ class CertStore:
|
|||||||
self.certs = {k: v for k, v in self.certs.items() if v != d}
|
self.certs = {k: v for k, v in self.certs.items() if v != d}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_dhparam(path):
|
def load_dhparam(path: str):
|
||||||
|
|
||||||
# mitmproxy<=0.10 doesn't generate a dhparam file.
|
# mitmproxy<=0.10 doesn't generate a dhparam file.
|
||||||
# Create it now if necessary.
|
# Create it now if necessary.
|
||||||
@ -196,23 +197,28 @@ class CertStore:
|
|||||||
return dh
|
return dh
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_store(cls, path, basename, key_size, passphrase: typing.Optional[bytes] = None):
|
def from_store(cls, path: Union[Path, str], basename: str, key_size, passphrase: Optional[bytes] = None) -> "CertStore":
|
||||||
ca_path = os.path.join(path, basename + "-ca.pem")
|
path = Path(path)
|
||||||
if not os.path.exists(ca_path):
|
ca_file = path / f"{basename}-ca.pem"
|
||||||
key, ca = cls.create_store(path, basename, key_size)
|
dhparam_file = path / f"{basename}-dhparam.pem"
|
||||||
else:
|
if not ca_file.exists():
|
||||||
with open(ca_path, "rb") as f:
|
cls.create_store(path, basename, key_size)
|
||||||
raw = f.read()
|
return cls.from_files(ca_file, dhparam_file, passphrase)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_files(cls, ca_file: Path, dhparam_file: Path, passphrase: Optional[bytes] = None) -> "CertStore":
|
||||||
|
raw = ca_file.read_bytes()
|
||||||
ca = OpenSSL.crypto.load_certificate(
|
ca = OpenSSL.crypto.load_certificate(
|
||||||
OpenSSL.crypto.FILETYPE_PEM,
|
OpenSSL.crypto.FILETYPE_PEM,
|
||||||
raw)
|
raw
|
||||||
|
)
|
||||||
key = OpenSSL.crypto.load_privatekey(
|
key = OpenSSL.crypto.load_privatekey(
|
||||||
OpenSSL.crypto.FILETYPE_PEM,
|
OpenSSL.crypto.FILETYPE_PEM,
|
||||||
raw,
|
raw,
|
||||||
passphrase)
|
passphrase
|
||||||
dh_path = os.path.join(path, basename + "-dhparam.pem")
|
)
|
||||||
dh = cls.load_dhparam(dh_path)
|
dh = cls.load_dhparam(str(dhparam_file))
|
||||||
return cls(key, ca, ca_path, dh)
|
return cls(key, ca, str(ca_file), dh)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@ -230,16 +236,15 @@ class CertStore:
|
|||||||
os.umask(original_umask)
|
os.umask(original_umask)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_store(path, basename, key_size, organization=None, cn=None, expiry=DEFAULT_EXP):
|
def create_store(path: Path, basename: str, key_size: int, organization=None, cn=None, expiry=DEFAULT_EXP) -> None:
|
||||||
if not os.path.exists(path):
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
os.makedirs(path)
|
|
||||||
|
|
||||||
organization = organization or basename
|
organization = organization or basename
|
||||||
cn = cn or basename
|
cn = cn or basename
|
||||||
|
|
||||||
key, ca = create_ca(organization=organization, cn=cn, exp=expiry, key_size=key_size)
|
key, ca = create_ca(organization=organization, cn=cn, exp=expiry, key_size=key_size)
|
||||||
# Dump the CA plus private key
|
# Dump the CA plus private key
|
||||||
with CertStore.umask_secret(), open(os.path.join(path, basename + "-ca.pem"), "wb") as f:
|
with CertStore.umask_secret(), (path / f"{basename}-ca.pem").open("wb") as f:
|
||||||
f.write(
|
f.write(
|
||||||
OpenSSL.crypto.dump_privatekey(
|
OpenSSL.crypto.dump_privatekey(
|
||||||
OpenSSL.crypto.FILETYPE_PEM,
|
OpenSSL.crypto.FILETYPE_PEM,
|
||||||
@ -250,38 +255,36 @@ class CertStore:
|
|||||||
ca))
|
ca))
|
||||||
|
|
||||||
# Dump the certificate in PEM format
|
# Dump the certificate in PEM format
|
||||||
with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f:
|
with (path / f"{basename}-ca-cert.pem").open("wb") as f:
|
||||||
f.write(
|
f.write(
|
||||||
OpenSSL.crypto.dump_certificate(
|
OpenSSL.crypto.dump_certificate(
|
||||||
OpenSSL.crypto.FILETYPE_PEM,
|
OpenSSL.crypto.FILETYPE_PEM,
|
||||||
ca))
|
ca))
|
||||||
|
|
||||||
# Create a .cer file with the same contents for Android
|
# Create a .cer file with the same contents for Android
|
||||||
with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f:
|
with (path / f"{basename}-ca-cert.cer").open("wb") as f:
|
||||||
f.write(
|
f.write(
|
||||||
OpenSSL.crypto.dump_certificate(
|
OpenSSL.crypto.dump_certificate(
|
||||||
OpenSSL.crypto.FILETYPE_PEM,
|
OpenSSL.crypto.FILETYPE_PEM,
|
||||||
ca))
|
ca))
|
||||||
|
|
||||||
# Dump the certificate in PKCS12 format for Windows devices
|
# Dump the certificate in PKCS12 format for Windows devices
|
||||||
with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f:
|
with (path / f"{basename}-ca-cert.p12").open("wb") as f:
|
||||||
p12 = OpenSSL.crypto.PKCS12()
|
p12 = OpenSSL.crypto.PKCS12()
|
||||||
p12.set_certificate(ca)
|
p12.set_certificate(ca)
|
||||||
f.write(p12.export())
|
f.write(p12.export())
|
||||||
|
|
||||||
# Dump the certificate and key in a PKCS12 format for Windows devices
|
# Dump the certificate and key in a PKCS12 format for Windows devices
|
||||||
with CertStore.umask_secret(), open(os.path.join(path, basename + "-ca.p12"), "wb") as f:
|
with CertStore.umask_secret(), (path / f"{basename}-ca.p12").open("wb") as f:
|
||||||
p12 = OpenSSL.crypto.PKCS12()
|
p12 = OpenSSL.crypto.PKCS12()
|
||||||
p12.set_certificate(ca)
|
p12.set_certificate(ca)
|
||||||
p12.set_privatekey(key)
|
p12.set_privatekey(key)
|
||||||
f.write(p12.export())
|
f.write(p12.export())
|
||||||
|
|
||||||
with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f:
|
with (path / f"{basename}-dhparam.pem").open("wb") as f:
|
||||||
f.write(DEFAULT_DHPARAM)
|
f.write(DEFAULT_DHPARAM)
|
||||||
|
|
||||||
return key, ca
|
def add_cert_file(self, spec: str, path: str, passphrase: Optional[bytes] = None) -> None:
|
||||||
|
|
||||||
def add_cert_file(self, spec: str, path: str, passphrase: typing.Optional[bytes] = None) -> None:
|
|
||||||
with open(path, "rb") as f:
|
with open(path, "rb") as f:
|
||||||
raw = f.read()
|
raw = f.read()
|
||||||
cert = Cert(
|
cert = Cert(
|
||||||
@ -313,7 +316,7 @@ class CertStore:
|
|||||||
self.certs[i] = entry
|
self.certs[i] = entry
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def asterisk_forms(dn: bytes) -> typing.List[bytes]:
|
def asterisk_forms(dn: bytes) -> List[bytes]:
|
||||||
"""
|
"""
|
||||||
Return all asterisk forms for a domain. For example, for www.example.com this will return
|
Return all asterisk forms for a domain. For example, for www.example.com this will return
|
||||||
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
|
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
|
||||||
@ -326,10 +329,10 @@ class CertStore:
|
|||||||
|
|
||||||
def get_cert(
|
def get_cert(
|
||||||
self,
|
self,
|
||||||
commonname: typing.Optional[bytes],
|
commonname: Optional[bytes],
|
||||||
sans: typing.List[bytes],
|
sans: List[bytes],
|
||||||
organization: typing.Optional[bytes] = None
|
organization: Optional[bytes] = None
|
||||||
) -> typing.Tuple["Cert", OpenSSL.SSL.PKey, str]:
|
) -> Tuple["Cert", OpenSSL.SSL.PKey, str]:
|
||||||
"""
|
"""
|
||||||
Returns an (cert, privkey, cert_chain) tuple.
|
Returns an (cert, privkey, cert_chain) tuple.
|
||||||
|
|
||||||
@ -341,7 +344,7 @@ class CertStore:
|
|||||||
organization: Organization name for the generated certificate.
|
organization: Organization name for the generated certificate.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
potential_keys: typing.List[TCertId] = []
|
potential_keys: List[TCertId] = []
|
||||||
if commonname:
|
if commonname:
|
||||||
potential_keys.extend(self.asterisk_forms(commonname))
|
potential_keys.extend(self.asterisk_forms(commonname))
|
||||||
for s in sans:
|
for s in sans:
|
||||||
@ -467,7 +470,7 @@ class Cert(serializable.Serializable):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cn(self) -> typing.Optional[bytes]:
|
def cn(self) -> Optional[bytes]:
|
||||||
c = None
|
c = None
|
||||||
for i in self.subject:
|
for i in self.subject:
|
||||||
if i[0] == b"CN":
|
if i[0] == b"CN":
|
||||||
@ -475,7 +478,7 @@ class Cert(serializable.Serializable):
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def organization(self) -> typing.Optional[bytes]:
|
def organization(self) -> Optional[bytes]:
|
||||||
c = None
|
c = None
|
||||||
for i in self.subject:
|
for i in self.subject:
|
||||||
if i[0] == b"O":
|
if i[0] == b"O":
|
||||||
@ -483,7 +486,7 @@ class Cert(serializable.Serializable):
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def altnames(self) -> typing.List[bytes]:
|
def altnames(self) -> List[bytes]:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
All DNS altnames.
|
All DNS altnames.
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, Callable, Optional, Tuple, List, Any
|
from typing import Iterable, Callable, Optional, Tuple, List, Any, BinaryIO
|
||||||
|
|
||||||
import certifi
|
import certifi
|
||||||
from kaitaistruct import KaitaiStream
|
from kaitaistruct import KaitaiStream
|
||||||
@ -22,6 +22,11 @@ class Method(Enum):
|
|||||||
|
|
||||||
|
|
||||||
# TODO: remove once https://github.com/pyca/pyopenssl/pull/985 has landed.
|
# TODO: remove once https://github.com/pyca/pyopenssl/pull/985 has landed.
|
||||||
|
try:
|
||||||
|
SSL._lib.TLS_server_method
|
||||||
|
except AttributeError as e: # pragma: no cover
|
||||||
|
raise RuntimeError("Your installation of the cryptography Python package is outdated.") from e
|
||||||
|
|
||||||
SSL.Context._methods.setdefault(Method.TLS_SERVER_METHOD.value, SSL._lib.TLS_server_method)
|
SSL.Context._methods.setdefault(Method.TLS_SERVER_METHOD.value, SSL._lib.TLS_server_method)
|
||||||
SSL.Context._methods.setdefault(Method.TLS_CLIENT_METHOD.value, SSL._lib.TLS_client_method)
|
SSL.Context._methods.setdefault(Method.TLS_CLIENT_METHOD.value, SSL._lib.TLS_client_method)
|
||||||
|
|
||||||
@ -52,7 +57,7 @@ DEFAULT_OPTIONS = (
|
|||||||
class MasterSecretLogger:
|
class MasterSecretLogger:
|
||||||
def __init__(self, filename: Path):
|
def __init__(self, filename: Path):
|
||||||
self.filename = filename.expanduser()
|
self.filename = filename.expanduser()
|
||||||
self.f = None
|
self.f: Optional[BinaryIO] = None
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
# required for functools.wraps, which pyOpenSSL uses.
|
# required for functools.wraps, which pyOpenSSL uses.
|
||||||
@ -89,7 +94,7 @@ def _create_ssl_context(
|
|||||||
method: Method,
|
method: Method,
|
||||||
min_version: Version,
|
min_version: Version,
|
||||||
max_version: Version,
|
max_version: Version,
|
||||||
cipher_list: List[str],
|
cipher_list: Optional[Iterable[str]],
|
||||||
) -> SSL.Context:
|
) -> SSL.Context:
|
||||||
context = SSL.Context(method.value)
|
context = SSL.Context(method.value)
|
||||||
|
|
||||||
@ -105,7 +110,7 @@ def _create_ssl_context(
|
|||||||
context.set_options(DEFAULT_OPTIONS)
|
context.set_options(DEFAULT_OPTIONS)
|
||||||
|
|
||||||
# Cipher List
|
# Cipher List
|
||||||
if cipher_list:
|
if cipher_list is not None:
|
||||||
try:
|
try:
|
||||||
context.set_cipher_list(b":".join(x.encode() for x in cipher_list))
|
context.set_cipher_list(b":".join(x.encode() for x in cipher_list))
|
||||||
except SSL.Error as v:
|
except SSL.Error as v:
|
||||||
@ -122,13 +127,13 @@ def create_proxy_server_context(
|
|||||||
*,
|
*,
|
||||||
min_version: Version,
|
min_version: Version,
|
||||||
max_version: Version,
|
max_version: Version,
|
||||||
cipher_list: List[str],
|
cipher_list: Optional[Iterable[str]],
|
||||||
verify: Verify,
|
verify: Verify,
|
||||||
sni: Optional[bytes],
|
sni: Optional[bytes],
|
||||||
ca_path: Path,
|
ca_path: Optional[str],
|
||||||
ca_pemfile: Path,
|
ca_pemfile: Optional[str],
|
||||||
client_cert: Path,
|
client_cert: Optional[str],
|
||||||
alpn_protos: Iterable[bytes],
|
alpn_protos: Optional[Iterable[bytes]],
|
||||||
) -> SSL.Context:
|
) -> SSL.Context:
|
||||||
context: SSL.Context = _create_ssl_context(
|
context: SSL.Context = _create_ssl_context(
|
||||||
method=Method.TLS_CLIENT_METHOD,
|
method=Method.TLS_CLIENT_METHOD,
|
||||||
@ -165,8 +170,8 @@ def create_proxy_server_context(
|
|||||||
# Client Certs
|
# Client Certs
|
||||||
if client_cert:
|
if client_cert:
|
||||||
try:
|
try:
|
||||||
context.use_privatekey_file(str(client_cert))
|
context.use_privatekey_file(client_cert)
|
||||||
context.use_certificate_chain_file(str(client_cert))
|
context.use_certificate_chain_file(client_cert)
|
||||||
except SSL.Error as v:
|
except SSL.Error as v:
|
||||||
raise exceptions.TlsException(f"TLS client certificate error: {v}")
|
raise exceptions.TlsException(f"TLS client certificate error: {v}")
|
||||||
|
|
||||||
@ -181,11 +186,11 @@ def create_client_proxy_context(
|
|||||||
*,
|
*,
|
||||||
min_version: Version,
|
min_version: Version,
|
||||||
max_version: Version,
|
max_version: Version,
|
||||||
cipher_list: List[str],
|
cipher_list: Optional[Iterable[str]],
|
||||||
cert: certs.Cert,
|
cert: certs.Cert,
|
||||||
key: SSL.PKey,
|
key: SSL.PKey,
|
||||||
chain_file: str,
|
chain_file: str,
|
||||||
alpn_select_callback: Callable[[SSL.Connection, List[bytes]], Any],
|
alpn_select_callback: Optional[Callable[[SSL.Connection, List[bytes]], Any]],
|
||||||
request_client_cert: bool,
|
request_client_cert: bool,
|
||||||
extra_chain_certs: Iterable[certs.Cert],
|
extra_chain_certs: Iterable[certs.Cert],
|
||||||
dhparams,
|
dhparams,
|
||||||
@ -249,8 +254,9 @@ def is_tls_record_magic(d):
|
|||||||
"""
|
"""
|
||||||
d = d[:3]
|
d = d[:3]
|
||||||
|
|
||||||
# TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2
|
# TLS ClientHello magic, works for SSLv3, TLSv1.0, TLSv1.1, TLSv1.2, and TLSv1.3
|
||||||
# http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello
|
# http://www.moserware.com/2009/06/first-few-milliseconds-of-https.html#client-hello
|
||||||
|
# https://tls13.ulfheim.net/
|
||||||
return (
|
return (
|
||||||
len(d) == 3 and
|
len(d) == 3 and
|
||||||
d[0] == 0x16 and
|
d[0] == 0x16 and
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from OpenSSL import SSL
|
||||||
|
from mitmproxy import certs
|
||||||
from mitmproxy.net import tls
|
from mitmproxy.net import tls
|
||||||
|
|
||||||
CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex(
|
CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex(
|
||||||
@ -17,76 +21,59 @@ def test_make_master_secret_logger():
|
|||||||
assert isinstance(tls.make_master_secret_logger("filepath"), tls.MasterSecretLogger)
|
assert isinstance(tls.make_master_secret_logger("filepath"), tls.MasterSecretLogger)
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
def test_sslkeylogfile(tdata, monkeypatch):
|
def test_sslkeylogfile(tdata, monkeypatch):
|
||||||
keylog = []
|
keylog = []
|
||||||
monkeypatch.setattr(tls, "log_master_secret", lambda conn, secrets: keylog.append(secrets))
|
monkeypatch.setattr(tls, "log_master_secret", lambda conn, secrets: keylog.append(secrets))
|
||||||
|
|
||||||
ctx = tls.create_client_context()
|
store = certs.CertStore.from_files(
|
||||||
|
Path(tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.pem")),
|
||||||
|
Path(tdata.path("mitmproxy/net/data/dhparam.pem"))
|
||||||
|
)
|
||||||
|
cert, key, chain_file = store.get_cert(b"example.com", [], None)
|
||||||
|
|
||||||
ta = tlsconfig.TlsConfig()
|
cctx = tls.create_proxy_server_context(
|
||||||
with taddons.context(ta) as tctx:
|
min_version=tls.DEFAULT_MIN_VERSION,
|
||||||
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
|
max_version=tls.DEFAULT_MAX_VERSION,
|
||||||
ctx.server.address = ("example.mitmproxy.org", 443)
|
cipher_list=None,
|
||||||
tctx.configure(ta, ssl_verify_upstream_trusted_ca=tdata.path(
|
verify=tls.Verify.VERIFY_NONE,
|
||||||
"mitmproxy/net/data/verificationcerts/trusted-root.crt"))
|
sni=None,
|
||||||
|
ca_path=None,
|
||||||
tls_start = tls.TlsStartData(ctx.server, context=ctx)
|
ca_pemfile=None,
|
||||||
ta.tls_start(tls_start)
|
client_cert=None,
|
||||||
tssl_client = tls_start.ssl_conn
|
alpn_protos=(),
|
||||||
tssl_server = test_tls.SSLTest(server_side=True)
|
)
|
||||||
assert self.do_handshake(tssl_client, tssl_server)
|
sctx = tls.create_client_proxy_context(
|
||||||
"""
|
min_version=tls.DEFAULT_MIN_VERSION,
|
||||||
|
max_version=tls.DEFAULT_MAX_VERSION,
|
||||||
"""
|
cipher_list=None,
|
||||||
class TestMasterSecretLogger(tservers.ServerTestBase):
|
cert=cert,
|
||||||
handler = EchoHandler
|
key=key,
|
||||||
ssl = dict(
|
chain_file=chain_file,
|
||||||
cipher_list="AES256-SHA"
|
alpn_select_callback=None,
|
||||||
|
request_client_cert=False,
|
||||||
|
extra_chain_certs=(),
|
||||||
|
dhparams=store.dhparams,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_log(self, tmpdir):
|
server = SSL.Connection(sctx)
|
||||||
testval = b"echo!\n"
|
server.set_accept_state()
|
||||||
_logfun = tls.log_master_secret
|
|
||||||
|
|
||||||
logfile = str(tmpdir.join("foo", "bar", "logfile"))
|
client = SSL.Connection(cctx)
|
||||||
tls.log_master_secret = tls.MasterSecretLogger(logfile)
|
client.set_connect_state()
|
||||||
|
|
||||||
c = TCPClient(("127.0.0.1", self.port))
|
read, write = client, server
|
||||||
with c.connect():
|
while True:
|
||||||
c.convert_to_tls()
|
try:
|
||||||
c.wfile.write(testval)
|
print(read)
|
||||||
c.wfile.flush()
|
read.do_handshake()
|
||||||
assert c.rfile.readline() == testval
|
except SSL.WantReadError:
|
||||||
c.finish()
|
write.bio_write(read.bio_read(2 ** 16))
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
read, write = write, read
|
||||||
|
|
||||||
tls.log_master_secret.close()
|
assert keylog
|
||||||
with open(logfile, "rb") as f:
|
assert keylog[0].startswith(b"SERVER_HANDSHAKE_TRAFFIC_SECRET")
|
||||||
assert f.read().count(b"SERVER_HANDSHAKE_TRAFFIC_SECRET") >= 2
|
|
||||||
|
|
||||||
tls.log_master_secret = _logfun
|
|
||||||
|
|
||||||
def test_create_logfun(self):
|
|
||||||
assert isinstance(
|
|
||||||
tls.MasterSecretLogger.create_logfun("test"),
|
|
||||||
tls.MasterSecretLogger)
|
|
||||||
assert not tls.MasterSecretLogger.create_logfun(False)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestTLSInvalid:
|
|
||||||
def test_invalid_ssl_method_should_fail(self):
|
|
||||||
fake_ssl_method = 100500
|
|
||||||
with pytest.raises(exceptions.TlsException):
|
|
||||||
tls.create_proxy_server_context(method=fake_ssl_method)
|
|
||||||
|
|
||||||
def test_alpn_error(self):
|
|
||||||
with pytest.raises(exceptions.TlsException, match="must be a function"):
|
|
||||||
tls.create_proxy_server_context(alpn_select_callback="foo")
|
|
||||||
|
|
||||||
with pytest.raises(exceptions.TlsException, match="ALPN error"):
|
|
||||||
tls.create_proxy_server_context(alpn_select="foo", alpn_select_callback="bar")
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_record_magic():
|
def test_is_record_magic():
|
||||||
|
@ -206,7 +206,7 @@ class TestCert:
|
|||||||
assert x == c
|
assert x == c
|
||||||
|
|
||||||
def test_from_store_with_passphrase(self, tdata, tmpdir):
|
def test_from_store_with_passphrase(self, tdata, tmpdir):
|
||||||
ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, "password")
|
ca = certs.CertStore.from_store(str(tmpdir), "mitmproxy", 2048, b"password")
|
||||||
ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), "password")
|
ca.add_cert_file("*", tdata.path("mitmproxy/data/mitmproxy.pem"), b"password")
|
||||||
|
|
||||||
assert ca.get_cert(b"foo", [])
|
assert ca.get_cert(b"foo", [])
|
||||||
|
Loading…
Reference in New Issue
Block a user