mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 02:10:59 +00:00
sni/alpn: bytes -> str
This commit is contained in:
parent
85c5275ec1
commit
abbe9eeb79
@ -66,7 +66,7 @@ class NextLayer:
|
||||
pass
|
||||
else:
|
||||
if sni:
|
||||
hostnames.append(sni.decode("idna"))
|
||||
hostnames.append(sni)
|
||||
|
||||
if not hostnames:
|
||||
return False
|
||||
|
@ -8,6 +8,7 @@ from mitmproxy.net import tls as net_tls
|
||||
from mitmproxy.options import CONF_BASENAME
|
||||
from mitmproxy.proxy import context
|
||||
from mitmproxy.proxy.layers import tls
|
||||
from mitmproxy.utils.strutils import always_bytes
|
||||
|
||||
# We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default.
|
||||
# https://ssl-config.mozilla.org/#config=old
|
||||
@ -35,7 +36,7 @@ def alpn_select_callback(conn: SSL.Connection, options: List[bytes]) -> Any:
|
||||
return server_alpn
|
||||
http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS
|
||||
for alpn in options: # client sends in order of preference, so we are nice and respect that.
|
||||
if alpn in http_alpns:
|
||||
if alpn.decode(errors="replace") in http_alpns:
|
||||
return alpn
|
||||
else:
|
||||
return SSL.NO_OVERLAPPING_PROTOCOLS
|
||||
@ -137,7 +138,7 @@ class TlsConfig:
|
||||
)
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||
tls_start.ssl_conn.set_app_data(AppData(
|
||||
server_alpn=server.alpn,
|
||||
server_alpn=always_bytes(server.alpn, "utf8", "replace"),
|
||||
http2=ctx.options.http2,
|
||||
))
|
||||
tls_start.ssl_conn.set_accept_state()
|
||||
@ -153,15 +154,15 @@ class TlsConfig:
|
||||
verify = net_tls.Verify.VERIFY_PEER
|
||||
|
||||
if server.sni is True:
|
||||
server.sni = client.sni or server.address[0].encode()
|
||||
sni = server.sni or None # make sure that false-y values are None
|
||||
server.sni = client.sni or server.address[0]
|
||||
sni: Optional[bytes] = server.sni.encode("ascii") if server.sni else None
|
||||
|
||||
if not server.alpn_offers:
|
||||
if client.alpn_offers:
|
||||
if ctx.options.http2:
|
||||
server.alpn_offers = tuple(client.alpn_offers)
|
||||
else:
|
||||
server.alpn_offers = tuple(x for x in client.alpn_offers if x != b"h2")
|
||||
server.alpn_offers = tuple(x for x in client.alpn_offers if x != "h2")
|
||||
elif client.tls_established:
|
||||
# We would perfectly support HTTP/1 -> HTTP/2, but we want to keep things on the same protocol version.
|
||||
# There are some edge cases where we want to mirror the regular server's behavior accurately,
|
||||
@ -171,6 +172,7 @@ class TlsConfig:
|
||||
server.alpn_offers = tls.HTTP_ALPNS
|
||||
else:
|
||||
server.alpn_offers = tls.HTTP1_ALPNS
|
||||
alpn_offers: List[bytes] = [alpn.encode() for alpn in server.alpn_offers]
|
||||
|
||||
if not server.cipher_list and ctx.options.ciphers_server:
|
||||
server.cipher_list = ctx.options.ciphers_server.split(":")
|
||||
@ -183,7 +185,7 @@ class TlsConfig:
|
||||
if os.path.isfile(client_certs):
|
||||
client_cert = client_certs
|
||||
else:
|
||||
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
|
||||
server_name: str = server.sni or server.address[0]
|
||||
p = os.path.join(client_certs, f"{server_name}.pem")
|
||||
if os.path.isfile(p):
|
||||
client_cert = p
|
||||
@ -197,11 +199,11 @@ class TlsConfig:
|
||||
ca_path=ctx.options.ssl_verify_upstream_trusted_confdir,
|
||||
ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca,
|
||||
client_cert=client_cert,
|
||||
alpn_protos=server.alpn_offers,
|
||||
alpn_protos=alpn_offers,
|
||||
)
|
||||
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||
tls_start.ssl_conn.set_tlsext_host_name(server.sni)
|
||||
tls_start.ssl_conn.set_tlsext_host_name(sni)
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
|
||||
def running(self):
|
||||
@ -250,8 +252,8 @@ class TlsConfig:
|
||||
This function determines the Common Name (CN), Subject Alternative Names (SANs) and Organization Name
|
||||
our certificate should have and then fetches a matching cert from the certstore.
|
||||
"""
|
||||
altnames: List[bytes] = []
|
||||
organization: Optional[bytes] = None
|
||||
altnames: List[str] = []
|
||||
organization: Optional[str] = None
|
||||
|
||||
# Use upstream certificate if available.
|
||||
if conn_context.server.certificate_list:
|
||||
@ -266,11 +268,11 @@ class TlsConfig:
|
||||
if conn_context.client.sni:
|
||||
altnames.append(conn_context.client.sni)
|
||||
elif conn_context.server.address:
|
||||
altnames.append(conn_context.server.address[0].encode("idna"))
|
||||
altnames.append(conn_context.server.address[0])
|
||||
|
||||
# As a last resort, add *something* so that we have a certificate to serve.
|
||||
if not altnames:
|
||||
altnames.append(b"mitmproxy")
|
||||
altnames.append("mitmproxy")
|
||||
|
||||
# only keep first occurrence of each hostname
|
||||
altnames = list(dict.fromkeys(altnames))
|
||||
|
@ -109,21 +109,21 @@ class Cert(serializable.Serializable):
|
||||
return public_key.__class__.__name__.replace("PublicKey", "").replace("_", ""), -1
|
||||
|
||||
@property
|
||||
def cn(self) -> Optional[bytes]: # TODO: make this return str
|
||||
def cn(self) -> Optional[str]:
|
||||
attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
|
||||
if attrs:
|
||||
return attrs[0].value.encode()
|
||||
return attrs[0].value
|
||||
return None
|
||||
|
||||
@property
|
||||
def organization(self) -> Optional[bytes]: # TODO: make this return str
|
||||
def organization(self) -> Optional[str]:
|
||||
attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.ORGANIZATION_NAME)
|
||||
if attrs:
|
||||
return attrs[0].value.encode()
|
||||
return attrs[0].value
|
||||
return None
|
||||
|
||||
@property
|
||||
def altnames(self) -> List[bytes]: # TODO: make this return str
|
||||
def altnames(self) -> List[str]:
|
||||
"""
|
||||
Get all SubjectAlternativeName DNS altnames.
|
||||
"""
|
||||
@ -133,9 +133,9 @@ class Cert(serializable.Serializable):
|
||||
return []
|
||||
else:
|
||||
return (
|
||||
[x.encode() for x in ext.get_values_for_type(x509.DNSName)]
|
||||
ext.get_values_for_type(x509.DNSName)
|
||||
+
|
||||
[str(x).encode() for x in ext.get_values_for_type(x509.IPAddress)]
|
||||
[str(x) for x in ext.get_values_for_type(x509.IPAddress)]
|
||||
)
|
||||
|
||||
|
||||
@ -191,9 +191,9 @@ def create_ca(
|
||||
def dummy_cert(
|
||||
privkey: rsa.RSAPrivateKey,
|
||||
cacert: x509.Certificate,
|
||||
commonname: Optional[bytes],
|
||||
sans: List[bytes],
|
||||
organization: Optional[bytes] = None,
|
||||
commonname: Optional[str],
|
||||
sans: List[str],
|
||||
organization: Optional[str] = None,
|
||||
) -> Cert:
|
||||
"""
|
||||
Generates a dummy certificate.
|
||||
@ -206,10 +206,6 @@ def dummy_cert(
|
||||
|
||||
Returns cert if operation succeeded, None if not.
|
||||
"""
|
||||
XX_commonname: Optional[str] = commonname.decode("idna") if commonname else None
|
||||
XX_organization: Optional[str] = organization.decode() if organization else None
|
||||
XX_sans: List[str] = [x.decode("ascii") for x in sans]
|
||||
|
||||
builder = x509.CertificateBuilder()
|
||||
builder = builder.issuer_name(cacert.subject)
|
||||
builder = builder.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False)
|
||||
@ -221,19 +217,19 @@ def dummy_cert(
|
||||
|
||||
subject = []
|
||||
is_valid_commonname = (
|
||||
XX_commonname is not None and len(XX_commonname) < 64
|
||||
commonname is not None and len(commonname) < 64
|
||||
)
|
||||
if is_valid_commonname:
|
||||
assert XX_commonname is not None
|
||||
subject.append(x509.NameAttribute(NameOID.COMMON_NAME, XX_commonname))
|
||||
if XX_organization is not None:
|
||||
assert XX_organization is not None
|
||||
subject.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, XX_organization))
|
||||
assert commonname is not None
|
||||
subject.append(x509.NameAttribute(NameOID.COMMON_NAME, commonname))
|
||||
if organization is not None:
|
||||
assert organization is not None
|
||||
subject.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization))
|
||||
builder = builder.subject_name(x509.Name(subject))
|
||||
builder = builder.serial_number(x509.random_serial_number())
|
||||
|
||||
ss: List[x509.GeneralName] = []
|
||||
for x in XX_sans:
|
||||
for x in sans:
|
||||
try:
|
||||
ip = ipaddress.ip_address(x)
|
||||
except ValueError:
|
||||
@ -253,8 +249,8 @@ class CertStoreEntry:
|
||||
chain_file: Optional[Path]
|
||||
|
||||
|
||||
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
|
||||
TGeneratedCertId = Tuple[Optional[bytes], Tuple[bytes, ...]] # (common_name, sans)
|
||||
TCustomCertId = str # manually provided certs (e.g. mitmproxy's --certs)
|
||||
TGeneratedCertId = Tuple[Optional[str], Tuple[str, ...]] # (common_name, sans)
|
||||
TCertId = Union[TCustomCertId, TGeneratedCertId]
|
||||
|
||||
DHParams = NewType("DHParams", bytes)
|
||||
@ -415,10 +411,10 @@ class CertStore:
|
||||
|
||||
self.add_cert(
|
||||
CertStoreEntry(cert, key, path),
|
||||
spec.encode("idna")
|
||||
spec
|
||||
)
|
||||
|
||||
def add_cert(self, entry: CertStoreEntry, *names: bytes) -> None:
|
||||
def add_cert(self, entry: CertStoreEntry, *names: str) -> None:
|
||||
"""
|
||||
Adds a cert to the certstore. We register the CN in the cert plus
|
||||
any SANs, and also the list of names provided as an argument.
|
||||
@ -431,22 +427,22 @@ class CertStore:
|
||||
self.certs[i] = entry
|
||||
|
||||
@staticmethod
|
||||
def asterisk_forms(dn: bytes) -> List[bytes]:
|
||||
def asterisk_forms(dn: str) -> List[str]:
|
||||
"""
|
||||
Return all asterisk forms for a domain. For example, for www.example.com this will return
|
||||
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
|
||||
"""
|
||||
parts = dn.split(b".")
|
||||
parts = dn.split(".")
|
||||
ret = [dn]
|
||||
for i in range(1, len(parts)):
|
||||
ret.append(b"*." + b".".join(parts[i:]))
|
||||
ret.append("*." + ".".join(parts[i:]))
|
||||
return ret
|
||||
|
||||
def get_cert(
|
||||
self,
|
||||
commonname: Optional[bytes],
|
||||
sans: List[bytes],
|
||||
organization: Optional[bytes] = None
|
||||
commonname: Optional[str],
|
||||
sans: List[str],
|
||||
organization: Optional[str] = None
|
||||
) -> CertStoreEntry:
|
||||
"""
|
||||
commonname: Common name for the generated certificate. Must be a
|
||||
@ -462,7 +458,7 @@ class CertStore:
|
||||
potential_keys.extend(self.asterisk_forms(commonname))
|
||||
for s in sans:
|
||||
potential_keys.extend(self.asterisk_forms(s))
|
||||
potential_keys.append(b"*")
|
||||
potential_keys.append("*")
|
||||
potential_keys.append((commonname, tuple(sans)))
|
||||
|
||||
name = next(
|
||||
|
@ -230,6 +230,25 @@ def convert_9_10(data):
|
||||
return data
|
||||
|
||||
|
||||
def convert_10_11(data):
|
||||
data["version"] = 11
|
||||
|
||||
def conv_conn(conn):
|
||||
conn["sni"] = strutils.always_str(conn["sni"], "ascii", "backslashreplace")
|
||||
conn["alpn"] = strutils.always_str(conn.pop("alpn_proto_negotiated"), "utf8", "backslashreplace")
|
||||
conn["alpn_offers"] = [
|
||||
strutils.always_str(alpn, "utf8", "backslashreplace")
|
||||
for alpn in (conn["alpn_offers"] or [])
|
||||
]
|
||||
|
||||
conv_conn(data["client_conn"])
|
||||
conv_conn(data["server_conn"])
|
||||
if data["server_conn"]["via"]:
|
||||
conv_conn(data["server_conn"]["via"])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _convert_dict_keys(o: Any) -> Any:
|
||||
if isinstance(o, dict):
|
||||
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
|
||||
@ -287,6 +306,7 @@ converters = {
|
||||
7: convert_7_8,
|
||||
8: convert_8_9,
|
||||
9: convert_9_10,
|
||||
10: convert_10_11,
|
||||
}
|
||||
|
||||
|
||||
|
@ -279,7 +279,7 @@ class ClientHello:
|
||||
return self._client_hello.cipher_suites.cipher_suites
|
||||
|
||||
@property
|
||||
def sni(self) -> Optional[bytes]:
|
||||
def sni(self) -> Optional[str]:
|
||||
if self._client_hello.extensions:
|
||||
for extension in self._client_hello.extensions.extensions:
|
||||
is_valid_sni_extension = (
|
||||
@ -289,15 +289,18 @@ class ClientHello:
|
||||
check.is_valid_host(extension.body.server_names[0].host_name)
|
||||
)
|
||||
if is_valid_sni_extension:
|
||||
return extension.body.server_names[0].host_name
|
||||
return extension.body.server_names[0].host_name.decode("ascii")
|
||||
return None
|
||||
|
||||
@property
|
||||
def alpn_protocols(self):
|
||||
def alpn_protocols(self) -> List[str]:
|
||||
if self._client_hello.extensions:
|
||||
for extension in self._client_hello.extensions.extensions:
|
||||
if extension.type == 0x10:
|
||||
return list(x.name for x in extension.body.alpn_protocols)
|
||||
try:
|
||||
return [x.name.decode() for x in extension.body.alpn_protocols]
|
||||
except UnicodeDecodeError:
|
||||
return []
|
||||
return []
|
||||
|
||||
@property
|
||||
|
@ -56,8 +56,8 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
|
||||
TLS version, with the exception of the end-entity certificate which
|
||||
MUST be first.
|
||||
"""
|
||||
alpn: Optional[bytes] = None
|
||||
alpn_offers: Sequence[bytes] = ()
|
||||
alpn: Optional[str] = None
|
||||
alpn_offers: Sequence[str] = ()
|
||||
|
||||
# we may want to add SSL_CIPHER_description here, but that's currently not exposed by cryptography
|
||||
cipher: Optional[str] = None
|
||||
@ -65,7 +65,7 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
|
||||
cipher_list: Sequence[str] = ()
|
||||
"""Ciphers accepted by the proxy server on this connection."""
|
||||
tls_version: Optional[str] = None
|
||||
sni: Union[bytes, Literal[True], None]
|
||||
sni: Union[str, Literal[True], None]
|
||||
|
||||
timestamp_end: Optional[float] = None
|
||||
"""Connection end timestamp"""
|
||||
@ -98,7 +98,9 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
|
||||
@property
|
||||
def alpn_proto_negotiated(self) -> Optional[bytes]: # pragma: no cover
|
||||
warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", DeprecationWarning)
|
||||
return self.alpn
|
||||
if self.alpn is not None:
|
||||
return self.alpn.encode()
|
||||
return None
|
||||
|
||||
|
||||
class Client(Connection):
|
||||
@ -110,7 +112,7 @@ class Client(Connection):
|
||||
"""TCP SYN received"""
|
||||
|
||||
mitmcert: Optional[certs.Cert] = None
|
||||
sni: Union[bytes, None] = None
|
||||
sni: Union[str, None] = None
|
||||
|
||||
def __init__(self, peername, sockname, timestamp_start):
|
||||
self.id = str(uuid.uuid4())
|
||||
@ -123,7 +125,7 @@ class Client(Connection):
|
||||
# This means we need to add all new fields to the old implementation.
|
||||
return {
|
||||
'address': self.peername,
|
||||
'alpn_proto_negotiated': self.alpn,
|
||||
'alpn': self.alpn,
|
||||
'cipher_name': self.cipher,
|
||||
'id': self.id,
|
||||
'mitmcert': self.mitmcert.get_state() if self.mitmcert is not None else None,
|
||||
@ -156,7 +158,7 @@ class Client(Connection):
|
||||
|
||||
def set_state(self, state):
|
||||
self.peername = tuple(state["address"]) if state["address"] else None
|
||||
self.alpn = state["alpn_proto_negotiated"]
|
||||
self.alpn = state["alpn"]
|
||||
self.cipher = state["cipher_name"]
|
||||
self.id = state["id"]
|
||||
self.sni = state["sni"]
|
||||
@ -217,7 +219,7 @@ class Server(Connection):
|
||||
timestamp_tcp_setup: Optional[float] = None
|
||||
"""TCP ACK received"""
|
||||
|
||||
sni: Union[bytes, Literal[True], None] = True
|
||||
sni: Union[str, Literal[True], None] = True
|
||||
"""True: client SNI, False: no SNI, bytes: custom value"""
|
||||
via: Optional[server_spec.ServerSpec] = None
|
||||
|
||||
@ -228,7 +230,7 @@ class Server(Connection):
|
||||
def get_state(self):
|
||||
return {
|
||||
'address': self.address,
|
||||
'alpn_proto_negotiated': self.alpn,
|
||||
'alpn': self.alpn,
|
||||
'id': self.id,
|
||||
'ip_address': self.peername,
|
||||
'sni': self.sni,
|
||||
@ -259,7 +261,7 @@ class Server(Connection):
|
||||
|
||||
def set_state(self, state):
|
||||
self.address = tuple(state["address"]) if state["address"] else None
|
||||
self.alpn = state["alpn_proto_negotiated"]
|
||||
self.alpn = state["alpn"]
|
||||
self.id = state["id"]
|
||||
self.peername = tuple(state["ip_address"]) if state["ip_address"] else None
|
||||
self.sni = state["sni"]
|
||||
|
@ -351,7 +351,7 @@ class HttpStream(layer.Layer):
|
||||
yield HttpErrorHook(self.flow)
|
||||
# For HTTP/2 we only want to kill the specific stream, for HTTP/1 we want to kill the connection
|
||||
# *without* sending an HTTP response (that could be achieved by the user by setting flow.response).
|
||||
if self.context.client.alpn == b"h2":
|
||||
if self.context.client.alpn == "h2":
|
||||
yield SendHttp(ResponseProtocolError(self.stream_id, "killed"), self.context.client)
|
||||
else:
|
||||
if self.context.client.state & ConnectionState.CAN_WRITE:
|
||||
@ -435,7 +435,7 @@ class HttpStream(layer.Layer):
|
||||
|
||||
stack = tunnel.LayerStack()
|
||||
if self.context.server.via.scheme == "https":
|
||||
http_proxy.sni = self.context.server.via.address[0].encode()
|
||||
http_proxy.sni = self.context.server.via.address[0]
|
||||
stack /= tls.ServerTLSLayer(self.context, http_proxy)
|
||||
stack /= _upstream_proxy.HttpUpstreamProxy(self.context, http_proxy, True)
|
||||
|
||||
@ -532,7 +532,7 @@ class HttpLayer(layer.Layer):
|
||||
self.command_sources = {}
|
||||
|
||||
http_conn: HttpConnection
|
||||
if self.context.client.alpn == b"h2":
|
||||
if self.context.client.alpn == "h2":
|
||||
http_conn = Http2Server(context.fork())
|
||||
else:
|
||||
http_conn = Http1Server(context.fork())
|
||||
@ -606,10 +606,10 @@ class HttpLayer(layer.Layer):
|
||||
for connection in self.connections:
|
||||
# see "tricky multiplexing edge case" in make_http_connection for an explanation
|
||||
conn_is_pending_or_h2 = (
|
||||
connection.alpn == b"h2"
|
||||
connection.alpn == "h2"
|
||||
or connection in self.waiting_for_establishment
|
||||
)
|
||||
h2_to_h1 = self.context.client.alpn == b"h2" and not conn_is_pending_or_h2
|
||||
h2_to_h1 = self.context.client.alpn == "h2" and not conn_is_pending_or_h2
|
||||
connection_suitable = (
|
||||
event.connection_spec_matches(connection)
|
||||
and not h2_to_h1
|
||||
@ -635,7 +635,7 @@ class HttpLayer(layer.Layer):
|
||||
|
||||
context.server = Server(event.address)
|
||||
if event.tls:
|
||||
context.server.sni = event.address[0].encode()
|
||||
context.server.sni = event.address[0]
|
||||
|
||||
if event.via:
|
||||
assert event.via.scheme in ("http", "https")
|
||||
@ -643,7 +643,7 @@ class HttpLayer(layer.Layer):
|
||||
|
||||
if event.via.scheme == "https":
|
||||
http_proxy.alpn_offers = tls.HTTP_ALPNS
|
||||
http_proxy.sni = event.via.address[0].encode()
|
||||
http_proxy.sni = event.via.address[0]
|
||||
stack /= tls.ServerTLSLayer(context, http_proxy)
|
||||
|
||||
send_connect = not (self.mode == HTTPMode.upstream and not event.tls)
|
||||
@ -679,7 +679,7 @@ class HttpLayer(layer.Layer):
|
||||
# that neither have a content-length specified nor a chunked transfer encoding.
|
||||
# We can't process these two flows to the same h1 connection as they would both have
|
||||
# "read until eof" semantics. The only workaround left is to open a separate connection for each flow.
|
||||
if not command.err and self.context.client.alpn == b"h2" and command.connection.alpn != b"h2":
|
||||
if not command.err and self.context.client.alpn == "h2" and command.connection.alpn != "h2":
|
||||
for cmd in waiting[1:]:
|
||||
yield from self.get_connection(cmd, reuse=False)
|
||||
break
|
||||
@ -695,7 +695,7 @@ class HttpClient(layer.Layer):
|
||||
err = yield commands.OpenConnection(self.context.server)
|
||||
if not err:
|
||||
child_layer: layer.Layer
|
||||
if self.context.server.alpn == b"h2":
|
||||
if self.context.server.alpn == "h2":
|
||||
child_layer = Http2Client(self.context)
|
||||
else:
|
||||
child_layer = Http1Client(self.context)
|
||||
|
@ -42,7 +42,7 @@ class ReverseProxy(DestinationKnown):
|
||||
|
||||
if spec.scheme not in ("http", "tcp"):
|
||||
if not self.context.options.keep_host_header:
|
||||
self.context.server.sni = spec.address[0].encode()
|
||||
self.context.server.sni = spec.address[0]
|
||||
self.child_layer = tls.ServerTLSLayer(self.context)
|
||||
else:
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
@ -91,8 +91,8 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
|
||||
return None
|
||||
|
||||
|
||||
HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
|
||||
HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS
|
||||
HTTP1_ALPNS = ("http/1.1", "http/1.0", "http/0.9")
|
||||
HTTP_ALPNS = ("h2",) + HTTP1_ALPNS
|
||||
|
||||
|
||||
# We need these classes as hooks can only have one argument at the moment.
|
||||
@ -183,6 +183,8 @@ class _TLSLayer(tunnel.TunnelLayer):
|
||||
err = f"OpenSSL {e!r}"
|
||||
return False, err
|
||||
else:
|
||||
# Here we set all attributes that are only known *after* the handshake.
|
||||
|
||||
# Get all peer certificates.
|
||||
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
|
||||
# If called on the client side, the stack also contains the peer's certificate; if called on the server
|
||||
@ -194,11 +196,8 @@ class _TLSLayer(tunnel.TunnelLayer):
|
||||
all_certs.insert(0, cert)
|
||||
|
||||
self.conn.timestamp_tls_setup = time.time()
|
||||
self.conn.sni = self.tls.get_servername()
|
||||
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
|
||||
self.conn.alpn = self.tls.get_alpn_proto_negotiated().decode()
|
||||
self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs]
|
||||
if isinstance(self.conn, context.Client):
|
||||
self.conn.mitmcert = certs.Cert.from_pyopenssl(self.tls.get_certificate())
|
||||
self.conn.cipher = self.tls.get_cipher_name()
|
||||
self.conn.tls_version = self.tls.get_protocol_version_name()
|
||||
if self.debug:
|
||||
@ -339,8 +338,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
|
||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||
if self.conn.sni:
|
||||
assert isinstance(self.conn.sni, bytes)
|
||||
dest = self.conn.sni.decode("idna")
|
||||
dest = self.conn.sni
|
||||
else:
|
||||
dest = human.format_address(self.context.server.address)
|
||||
if err.startswith("Cannot parse ClientHello"):
|
||||
|
@ -427,7 +427,8 @@ if __name__ == "__main__": # pragma: no cover
|
||||
tls_start.ssl_conn.set_accept_state()
|
||||
else:
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni)
|
||||
if tls_start.context.client.sni is not None:
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni.encode())
|
||||
|
||||
await SimpleConnectionHandler(reader, writer, opts, {
|
||||
"next_layer": next_layer,
|
||||
|
@ -158,7 +158,7 @@ def tclient_conn() -> context.Client:
|
||||
timestamp_end=946681206,
|
||||
sni="address",
|
||||
cipher_name="cipher",
|
||||
alpn_proto_negotiated=b"http/1.1",
|
||||
alpn="http/1.1",
|
||||
tls_version="TLSv1.2",
|
||||
tls_extensions=[(0x00, bytes.fromhex("000e00000b6578616d"))],
|
||||
state=0,
|
||||
@ -185,7 +185,7 @@ def tserver_conn() -> context.Server:
|
||||
timestamp_end=946681205,
|
||||
tls_established=True,
|
||||
sni="address",
|
||||
alpn_proto_negotiated=None,
|
||||
alpn=None,
|
||||
tls_version="TLSv1.2",
|
||||
via=None,
|
||||
state=0,
|
||||
|
@ -5,7 +5,6 @@ import mitmproxy.flow
|
||||
from mitmproxy import http
|
||||
from mitmproxy.tools.console import common, searchable
|
||||
from mitmproxy.utils import human
|
||||
from mitmproxy.utils import strutils
|
||||
|
||||
|
||||
def maybe_timestamp(base, attr):
|
||||
@ -40,23 +39,23 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
|
||||
text.append(urwid.Text([("head", "Metadata:")]))
|
||||
text.extend(common.format_keyvals(parts, indent=4))
|
||||
|
||||
if sc is not None and sc.ip_address:
|
||||
if sc is not None and sc.peername:
|
||||
text.append(urwid.Text([("head", "Server Connection:")]))
|
||||
parts = [
|
||||
("Address", human.format_address(sc.address)),
|
||||
]
|
||||
if sc.ip_address:
|
||||
parts.append(("Resolved Address", human.format_address(sc.ip_address)))
|
||||
if sc.peername:
|
||||
parts.append(("Resolved Address", human.format_address(sc.peername)))
|
||||
if resp:
|
||||
parts.append(("HTTP Version", resp.http_version))
|
||||
if sc.alpn_proto_negotiated:
|
||||
parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn_proto_negotiated)))
|
||||
if sc.alpn:
|
||||
parts.append(("ALPN", sc.alpn))
|
||||
|
||||
text.extend(
|
||||
common.format_keyvals(parts, indent=4)
|
||||
)
|
||||
|
||||
c = sc.cert
|
||||
c = sc.certificate_list[0]
|
||||
if c:
|
||||
text.append(urwid.Text([("head", "Server Certificate:")]))
|
||||
parts = [
|
||||
@ -65,39 +64,12 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
|
||||
("Valid to", str(c.notafter)),
|
||||
("Valid from", str(c.notbefore)),
|
||||
("Serial", str(c.serial)),
|
||||
(
|
||||
"Subject",
|
||||
urwid.BoxAdapter(
|
||||
urwid.ListBox(
|
||||
common.format_keyvals(
|
||||
c.subject,
|
||||
key_format="highlight"
|
||||
)
|
||||
),
|
||||
len(c.subject)
|
||||
)
|
||||
),
|
||||
(
|
||||
"Issuer",
|
||||
urwid.BoxAdapter(
|
||||
urwid.ListBox(
|
||||
common.format_keyvals(
|
||||
c.issuer,
|
||||
key_format="highlight"
|
||||
)
|
||||
),
|
||||
len(c.issuer)
|
||||
)
|
||||
)
|
||||
("Subject", urwid.Pile(common.format_keyvals(c.subject, key_format="highlight"))),
|
||||
("Issuer", urwid.Pile(common.format_keyvals(c.issuer, key_format="highlight")))
|
||||
]
|
||||
|
||||
if c.altnames:
|
||||
parts.append(
|
||||
(
|
||||
"Alt names",
|
||||
", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames)
|
||||
)
|
||||
)
|
||||
parts.append(("Alt names", ", ".join(c.altnames)))
|
||||
text.extend(
|
||||
common.format_keyvals(parts, indent=4)
|
||||
)
|
||||
@ -106,19 +78,18 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
|
||||
text.append(urwid.Text([("head", "Client Connection:")]))
|
||||
|
||||
parts = [
|
||||
("Address", "{}:{}".format(cc.address[0], cc.address[1])),
|
||||
("Address", human.format_address(cc.peername)),
|
||||
]
|
||||
if req:
|
||||
parts.append(("HTTP Version", req.http_version))
|
||||
if cc.tls_version:
|
||||
parts.append(("TLS Version", cc.tls_version))
|
||||
if cc.sni:
|
||||
parts.append(("Server Name Indication",
|
||||
strutils.bytes_to_escaped_str(strutils.always_bytes(cc.sni, "idna"))))
|
||||
if cc.cipher_name:
|
||||
parts.append(("Cipher Name", cc.cipher_name))
|
||||
if cc.alpn_proto_negotiated:
|
||||
parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn_proto_negotiated)))
|
||||
parts.append(("Server Name Indication", cc.sni))
|
||||
if cc.cipher:
|
||||
parts.append(("Cipher Name", cc.cipher))
|
||||
if cc.alpn:
|
||||
parts.append(("ALPN", cc.alpn))
|
||||
|
||||
text.extend(
|
||||
common.format_keyvals(parts, indent=4)
|
||||
|
@ -20,7 +20,6 @@ from mitmproxy import io
|
||||
from mitmproxy import log
|
||||
from mitmproxy import optmanager
|
||||
from mitmproxy import version
|
||||
from mitmproxy.utils.strutils import always_str
|
||||
|
||||
|
||||
def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
||||
@ -47,10 +46,9 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
||||
"timestamp_start": flow.client_conn.timestamp_start,
|
||||
"timestamp_tls_setup": flow.client_conn.timestamp_tls_setup,
|
||||
"timestamp_end": flow.client_conn.timestamp_end,
|
||||
# ideally idna, but we don't want errors
|
||||
"sni": always_str(flow.client_conn.sni, "ascii", "backslashreplace"),
|
||||
"sni": flow.client_conn.sni,
|
||||
"cipher_name": flow.client_conn.cipher,
|
||||
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
|
||||
"alpn_proto_negotiated": flow.client_conn.alpn,
|
||||
"tls_version": flow.client_conn.tls_version,
|
||||
}
|
||||
|
||||
@ -61,18 +59,14 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
||||
"ip_address": flow.server_conn.peername,
|
||||
"source_address": flow.server_conn.sockname,
|
||||
"tls_established": flow.server_conn.tls_established,
|
||||
"alpn_proto_negotiated": always_str(flow.server_conn.alpn, "ascii", "backslashreplace"),
|
||||
"sni": flow.server_conn.sni,
|
||||
"alpn_proto_negotiated": flow.server_conn.alpn,
|
||||
"tls_version": flow.server_conn.tls_version,
|
||||
"timestamp_start": flow.server_conn.timestamp_start,
|
||||
"timestamp_tcp_setup": flow.server_conn.timestamp_tcp_setup,
|
||||
"timestamp_tls_setup": flow.server_conn.timestamp_tls_setup,
|
||||
"timestamp_end": flow.server_conn.timestamp_end,
|
||||
}
|
||||
if flow.server_conn.sni is True:
|
||||
f["server_conn"] = None
|
||||
else:
|
||||
# ideally idna, but we don't want errors
|
||||
f["server_conn"] = always_str(flow.server_conn.sni, "ascii", "backslashreplace")
|
||||
if flow.error:
|
||||
f["error"] = flow.error.get_state()
|
||||
|
||||
|
@ -7,7 +7,7 @@ MITMPROXY = "mitmproxy " + VERSION
|
||||
|
||||
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
|
||||
# for each change in the file format.
|
||||
FLOW_FORMAT_VERSION = 10
|
||||
FLOW_FORMAT_VERSION = 11
|
||||
|
||||
|
||||
def get_dev_version() -> str:
|
||||
|
@ -58,20 +58,20 @@ class TestTlsConfig:
|
||||
|
||||
# Edge case first: We don't have _any_ idea about the server, so we just return "mitmproxy" as subject.
|
||||
entry = ta.get_cert(ctx)
|
||||
assert entry.cert.cn == b"mitmproxy"
|
||||
assert entry.cert.cn == "mitmproxy"
|
||||
|
||||
# Here we have an existing server connection...
|
||||
ctx.server.address = ("server-address.example", 443)
|
||||
with open(tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.crt"), "rb") as f:
|
||||
ctx.server.certificate_list = [certs.Cert.from_pem(f.read())]
|
||||
entry = ta.get_cert(ctx)
|
||||
assert entry.cert.cn == b"example.mitmproxy.org"
|
||||
assert entry.cert.altnames == [b"example.mitmproxy.org", b"server-address.example"]
|
||||
assert entry.cert.cn == "example.mitmproxy.org"
|
||||
assert entry.cert.altnames == ["example.mitmproxy.org", "server-address.example"]
|
||||
|
||||
# And now we also incorporate SNI.
|
||||
ctx.client.sni = b"sni.example"
|
||||
ctx.client.sni = "sni.example"
|
||||
entry = ta.get_cert(ctx)
|
||||
assert entry.cert.altnames == [b"example.mitmproxy.org", b"sni.example"]
|
||||
assert entry.cert.altnames == ["example.mitmproxy.org", "sni.example"]
|
||||
|
||||
def test_tls_clienthello(self):
|
||||
# only really testing for coverage here, there's no point in mirroring the individual conditions
|
||||
@ -127,7 +127,7 @@ class TestTlsConfig:
|
||||
ta = tlsconfig.TlsConfig()
|
||||
with taddons.context(ta) as tctx:
|
||||
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
|
||||
ctx.client.alpn_offers = [b"h2"]
|
||||
ctx.client.alpn_offers = ["h2"]
|
||||
ctx.client.cipher_list = ["TLS_AES_256_GCM_SHA384", "ECDHE-RSA-AES128-SHA"]
|
||||
ctx.server.address = ("example.mitmproxy.org", 443)
|
||||
|
||||
@ -185,8 +185,8 @@ class TestTlsConfig:
|
||||
ta.tls_start(tls_start)
|
||||
assert ctx.server.alpn_offers == expected
|
||||
|
||||
assert_alpn(True, tls.HTTP_ALPNS + (b"foo",), tls.HTTP_ALPNS + (b"foo",))
|
||||
assert_alpn(False, tls.HTTP_ALPNS + (b"foo",), tls.HTTP1_ALPNS + (b"foo",))
|
||||
assert_alpn(True, tls.HTTP_ALPNS + ("foo",), tls.HTTP_ALPNS + ("foo",))
|
||||
assert_alpn(False, tls.HTTP_ALPNS + ("foo",), tls.HTTP1_ALPNS + ("foo",))
|
||||
assert_alpn(True, [], tls.HTTP_ALPNS)
|
||||
assert_alpn(False, [], tls.HTTP1_ALPNS)
|
||||
ctx.client.timestamp_tls_setup = time.time()
|
||||
|
@ -29,7 +29,7 @@ def test_sslkeylogfile(tdata, monkeypatch):
|
||||
Path(tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.pem")),
|
||||
Path(tdata.path("mitmproxy/net/data/dhparam.pem"))
|
||||
)
|
||||
entry = store.get_cert(b"example.com", [], None)
|
||||
entry = store.get_cert("example.com", [], None)
|
||||
|
||||
cctx = tls.create_proxy_server_context(
|
||||
min_version=tls.DEFAULT_MIN_VERSION,
|
||||
@ -105,12 +105,12 @@ class TestClientHello:
|
||||
)
|
||||
c = tls.ClientHello(data)
|
||||
assert repr(c)
|
||||
assert c.sni == b'example.com'
|
||||
assert c.sni == 'example.com'
|
||||
assert c.cipher_suites == [
|
||||
49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161,
|
||||
49171, 49162, 49172, 156, 157, 47, 53, 10
|
||||
]
|
||||
assert c.alpn_protocols == [b'h2', b'http/1.1']
|
||||
assert c.alpn_protocols == ['h2', 'http/1.1']
|
||||
assert c.extensions == [
|
||||
(65281, b'\x00'),
|
||||
(0, b'\x00\x0e\x00\x00\x0bexample.com'),
|
||||
|
@ -44,7 +44,7 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
|
||||
|
||||
|
||||
def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
|
||||
tctx.client.alpn = b"h2"
|
||||
tctx.client.alpn = "h2"
|
||||
frame_factory = FrameFactory()
|
||||
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||
@ -58,7 +58,7 @@ def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
|
||||
|
||||
|
||||
def make_h2(open_connection: OpenConnection) -> None:
|
||||
open_connection.connection.alpn = b"h2"
|
||||
open_connection.connection.alpn = "h2"
|
||||
|
||||
|
||||
def test_simple(tctx):
|
||||
|
@ -208,7 +208,7 @@ def h2_frames(draw):
|
||||
|
||||
def h2_layer(opts):
|
||||
tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), opts)
|
||||
tctx.client.alpn = b"h2"
|
||||
tctx.client.alpn = "h2"
|
||||
|
||||
layer = http.HttpLayer(tctx, HTTPMode.regular)
|
||||
for _ in layer.handle_event(Start()):
|
||||
|
@ -22,7 +22,7 @@ def event_types(events):
|
||||
|
||||
|
||||
def h2_client(tctx: Context) -> Tuple[h2.connection.H2Connection, Playbook]:
|
||||
tctx.client.alpn = b"h2"
|
||||
tctx.client.alpn = "h2"
|
||||
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||
conn = h2.connection.H2Connection()
|
||||
|
@ -202,7 +202,7 @@ def test_reverse_proxy_tcp_over_tls(tctx: Context, monkeypatch, patch, connectio
|
||||
>> reply_tls_start()
|
||||
<< SendData(tctx.server, data)
|
||||
)
|
||||
assert tls.parse_client_hello(data()).sni == b"localhost"
|
||||
assert tls.parse_client_hello(data()).sni == "localhost"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("connection_strategy", ["eager", "lazy"])
|
||||
|
@ -76,7 +76,7 @@ def test_get_client_hello():
|
||||
|
||||
|
||||
def test_parse_client_hello():
|
||||
assert tls.parse_client_hello(client_hello_with_extensions).sni == b"example.com"
|
||||
assert tls.parse_client_hello(client_hello_with_extensions).sni == "example.com"
|
||||
assert tls.parse_client_hello(client_hello_with_extensions[:50]) is None
|
||||
with pytest.raises(ValueError):
|
||||
tls.parse_client_hello(client_hello_with_extensions[:183] + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00')
|
||||
@ -188,7 +188,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
||||
tls_start.ssl_conn = SSL.Connection(ssl_context)
|
||||
tls_start.ssl_conn.set_connect_state()
|
||||
# Set SNI
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni)
|
||||
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode("ascii"))
|
||||
|
||||
# Manually enable hostname verification.
|
||||
# Recent OpenSSL versions provide slightly nicer ways to do this, but they are not exposed in
|
||||
@ -202,7 +202,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
||||
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
|
||||
)
|
||||
SSL._openssl_assert(
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni, 0) == 1
|
||||
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode("ascii"), 0) == 1
|
||||
)
|
||||
|
||||
return tutils.reply(*args, side_effect=make_conn, **kwargs)
|
||||
@ -227,7 +227,7 @@ class TestServerTLS:
|
||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||
tctx.server.state = ConnectionState.OPEN
|
||||
tctx.server.address = ("example.mitmproxy.org", 443)
|
||||
tctx.server.sni = b"example.mitmproxy.org"
|
||||
tctx.server.sni = "example.mitmproxy.org"
|
||||
|
||||
tssl = SSLTest(server_side=True)
|
||||
|
||||
@ -280,7 +280,7 @@ class TestServerTLS:
|
||||
"""If the certificate is not trusted, we should fail."""
|
||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||
tctx.server.address = ("wrong.host.mitmproxy.org", 443)
|
||||
tctx.server.sni = b"wrong.host.mitmproxy.org"
|
||||
tctx.server.sni = "wrong.host.mitmproxy.org"
|
||||
|
||||
tssl = SSLTest(server_side=True)
|
||||
|
||||
@ -316,7 +316,7 @@ class TestServerTLS:
|
||||
def test_remote_speaks_no_tls(self, tctx):
|
||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||
tctx.server.state = ConnectionState.OPEN
|
||||
tctx.server.sni = b"example.mitmproxy.org"
|
||||
tctx.server.sni = "example.mitmproxy.org"
|
||||
|
||||
# send ClientHello, receive random garbage back
|
||||
data = tutils.Placeholder(bytes)
|
||||
@ -345,7 +345,7 @@ def make_client_tls_layer(
|
||||
|
||||
# Add some server config, this is needed anyways.
|
||||
tctx.server.address = ("example.mitmproxy.org", 443)
|
||||
tctx.server.sni = b"example.mitmproxy.org"
|
||||
tctx.server.sni = "example.mitmproxy.org"
|
||||
|
||||
tssl_client = SSLTest(**kwargs)
|
||||
# Start handshake.
|
||||
@ -446,8 +446,8 @@ class TestClientTLS:
|
||||
assert tctx.client.tls_established
|
||||
assert tctx.server.tls_established
|
||||
assert tctx.server.sni == tctx.client.sni
|
||||
assert tctx.client.alpn == b"quux"
|
||||
assert tctx.server.alpn == b"quux"
|
||||
assert tctx.client.alpn == "quux"
|
||||
assert tctx.server.alpn == "quux"
|
||||
_test_echo(playbook, tssl_server, tctx.server)
|
||||
_test_echo(playbook, tssl_client, tctx.client)
|
||||
|
||||
|
@ -46,10 +46,10 @@ class TestCertStore:
|
||||
|
||||
def test_create_explicit(self, tmpdir):
|
||||
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
assert ca.get_cert(b"foo", [])
|
||||
assert ca.get_cert("foo", [])
|
||||
|
||||
ca2 = certs.CertStore.from_store(str(tmpdir), "test", 2048)
|
||||
assert ca2.get_cert(b"foo", [])
|
||||
assert ca2.get_cert("foo", [])
|
||||
|
||||
assert ca.default_ca.serial == ca2.default_ca.serial
|
||||
|
||||
@ -57,51 +57,51 @@ class TestCertStore:
|
||||
assert tstore.get_cert(None, []).cert.cn is None
|
||||
|
||||
def test_sans(self, tstore):
|
||||
c1 = tstore.get_cert(b"foo.com", [b"*.bar.com"])
|
||||
tstore.get_cert(b"foo.bar.com", [])
|
||||
c1 = tstore.get_cert("foo.com", ["*.bar.com"])
|
||||
tstore.get_cert("foo.bar.com", [])
|
||||
# assert c1 == c2
|
||||
c3 = tstore.get_cert(b"bar.com", [])
|
||||
c3 = tstore.get_cert("bar.com", [])
|
||||
assert not c1 == c3
|
||||
|
||||
def test_sans_change(self, tstore):
|
||||
tstore.get_cert(b"foo.com", [b"*.bar.com"])
|
||||
entry = tstore.get_cert(b"foo.bar.com", [b"*.baz.com"])
|
||||
assert b"*.baz.com" in entry.cert.altnames
|
||||
tstore.get_cert("foo.com", ["*.bar.com"])
|
||||
entry = tstore.get_cert("foo.bar.com", ["*.baz.com"])
|
||||
assert "*.baz.com" in entry.cert.altnames
|
||||
|
||||
def test_expire(self, tstore):
|
||||
tstore.STORE_CAP = 3
|
||||
tstore.get_cert(b"one.com", [])
|
||||
tstore.get_cert(b"two.com", [])
|
||||
tstore.get_cert(b"three.com", [])
|
||||
tstore.get_cert("one.com", [])
|
||||
tstore.get_cert("two.com", [])
|
||||
tstore.get_cert("three.com", [])
|
||||
|
||||
assert (b"one.com", ()) in tstore.certs
|
||||
assert (b"two.com", ()) in tstore.certs
|
||||
assert (b"three.com", ()) in tstore.certs
|
||||
assert ("one.com", ()) in tstore.certs
|
||||
assert ("two.com", ()) in tstore.certs
|
||||
assert ("three.com", ()) in tstore.certs
|
||||
|
||||
tstore.get_cert(b"one.com", [])
|
||||
tstore.get_cert("one.com", [])
|
||||
|
||||
assert (b"one.com", ()) in tstore.certs
|
||||
assert (b"two.com", ()) in tstore.certs
|
||||
assert (b"three.com", ()) in tstore.certs
|
||||
assert ("one.com", ()) in tstore.certs
|
||||
assert ("two.com", ()) in tstore.certs
|
||||
assert ("three.com", ()) in tstore.certs
|
||||
|
||||
tstore.get_cert(b"four.com", [])
|
||||
tstore.get_cert("four.com", [])
|
||||
|
||||
assert (b"one.com", ()) not in tstore.certs
|
||||
assert (b"two.com", ()) in tstore.certs
|
||||
assert (b"three.com", ()) in tstore.certs
|
||||
assert (b"four.com", ()) in tstore.certs
|
||||
assert ("one.com", ()) not in tstore.certs
|
||||
assert ("two.com", ()) in tstore.certs
|
||||
assert ("three.com", ()) in tstore.certs
|
||||
assert ("four.com", ()) in tstore.certs
|
||||
|
||||
def test_overrides(self, tmp_path):
|
||||
ca1 = certs.CertStore.from_store(tmp_path / "ca1", "test", 2048)
|
||||
ca2 = certs.CertStore.from_store(tmp_path / "ca2", "test", 2048)
|
||||
assert not ca1.default_ca.serial == ca2.default_ca.serial
|
||||
|
||||
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
|
||||
dc = ca2.get_cert("foo.com", ["sans.example.com"])
|
||||
dcp = tmp_path / "dc"
|
||||
dcp.write_bytes(dc.cert.to_pem())
|
||||
ca1.add_cert_file("foo.com", dcp)
|
||||
|
||||
ret = ca1.get_cert(b"foo.com", [])
|
||||
ret = ca1.get_cert("foo.com", [])
|
||||
assert ret.cert.serial == dc.cert.serial
|
||||
|
||||
def test_create_dhparams(self, tmp_path):
|
||||
@ -124,13 +124,13 @@ class TestDummyCert:
|
||||
r = certs.dummy_cert(
|
||||
tstore.default_privatekey,
|
||||
tstore.default_ca._cert,
|
||||
b"foo.com",
|
||||
[b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"],
|
||||
b"Foo Ltd."
|
||||
"foo.com",
|
||||
["one.com", "two.com", "*.three.com", "127.0.0.1"],
|
||||
"Foo Ltd."
|
||||
)
|
||||
assert r.cn == b"foo.com"
|
||||
assert r.altnames == [b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"]
|
||||
assert r.organization == b"Foo Ltd."
|
||||
assert r.cn == "foo.com"
|
||||
assert r.altnames == ["one.com", "two.com", "*.three.com", "127.0.0.1"]
|
||||
assert r.organization == "Foo Ltd."
|
||||
|
||||
r = certs.dummy_cert(
|
||||
tstore.default_privatekey,
|
||||
@ -150,14 +150,14 @@ class TestCert:
|
||||
with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f:
|
||||
d = f.read()
|
||||
c1 = certs.Cert.from_pem(d)
|
||||
assert c1.cn == b"google.com"
|
||||
assert c1.cn == "google.com"
|
||||
assert len(c1.altnames) == 436
|
||||
assert c1.organization == b"Google Inc"
|
||||
assert c1.organization == "Google Inc"
|
||||
|
||||
with open(tdata.path("mitmproxy/net/data/text_cert_2"), "rb") as f:
|
||||
d = f.read()
|
||||
c2 = certs.Cert.from_pem(d)
|
||||
assert c2.cn == b"www.inode.co.nz"
|
||||
assert c2.cn == "www.inode.co.nz"
|
||||
assert len(c2.altnames) == 2
|
||||
assert c2.fingerprint()
|
||||
assert c2.notbefore
|
||||
@ -197,4 +197,4 @@ class TestCert:
|
||||
def test_from_store_with_passphrase(self, tdata, tstore):
|
||||
tstore.add_cert_file("*", Path(tdata.path("mitmproxy/data/mitmproxy.pem")), b"password")
|
||||
|
||||
assert tstore.get_cert(b"foo", [])
|
||||
assert tstore.get_cert("foo", [])
|
||||
|
@ -25,12 +25,10 @@ def test_format_keyvals():
|
||||
("ee", None),
|
||||
]
|
||||
)
|
||||
wrapped = urwid.BoxAdapter(
|
||||
urwid.ListBox(
|
||||
wrapped = urwid.Pile(
|
||||
urwid.SimpleFocusListWalker(
|
||||
common.format_keyvals([("foo", "bar")])
|
||||
)
|
||||
), 1
|
||||
)
|
||||
assert wrapped.render((30,))
|
||||
assert common.format_keyvals(
|
||||
|
Loading…
Reference in New Issue
Block a user