sni/alpn: bytes -> str

This commit is contained in:
Maximilian Hils 2020-12-30 20:23:25 +01:00
parent 85c5275ec1
commit abbe9eeb79
23 changed files with 186 additions and 201 deletions

View File

@ -66,7 +66,7 @@ class NextLayer:
pass
else:
if sni:
hostnames.append(sni.decode("idna"))
hostnames.append(sni)
if not hostnames:
return False

View File

@ -8,6 +8,7 @@ from mitmproxy.net import tls as net_tls
from mitmproxy.options import CONF_BASENAME
from mitmproxy.proxy import context
from mitmproxy.proxy.layers import tls
from mitmproxy.utils.strutils import always_bytes
# We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default.
# https://ssl-config.mozilla.org/#config=old
@ -35,7 +36,7 @@ def alpn_select_callback(conn: SSL.Connection, options: List[bytes]) -> Any:
return server_alpn
http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS
for alpn in options: # client sends in order of preference, so we are nice and respect that.
if alpn in http_alpns:
if alpn.decode(errors="replace") in http_alpns:
return alpn
else:
return SSL.NO_OVERLAPPING_PROTOCOLS
@ -137,7 +138,7 @@ class TlsConfig:
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
tls_start.ssl_conn.set_app_data(AppData(
server_alpn=server.alpn,
server_alpn=always_bytes(server.alpn, "utf8", "replace"),
http2=ctx.options.http2,
))
tls_start.ssl_conn.set_accept_state()
@ -153,15 +154,15 @@ class TlsConfig:
verify = net_tls.Verify.VERIFY_PEER
if server.sni is True:
server.sni = client.sni or server.address[0].encode()
sni = server.sni or None # make sure that false-y values are None
server.sni = client.sni or server.address[0]
sni: Optional[bytes] = server.sni.encode("ascii") if server.sni else None
if not server.alpn_offers:
if client.alpn_offers:
if ctx.options.http2:
server.alpn_offers = tuple(client.alpn_offers)
else:
server.alpn_offers = tuple(x for x in client.alpn_offers if x != b"h2")
server.alpn_offers = tuple(x for x in client.alpn_offers if x != "h2")
elif client.tls_established:
# We would perfectly support HTTP/1 -> HTTP/2, but we want to keep things on the same protocol version.
# There are some edge cases where we want to mirror the regular server's behavior accurately,
@ -171,6 +172,7 @@ class TlsConfig:
server.alpn_offers = tls.HTTP_ALPNS
else:
server.alpn_offers = tls.HTTP1_ALPNS
alpn_offers: List[bytes] = [alpn.encode() for alpn in server.alpn_offers]
if not server.cipher_list and ctx.options.ciphers_server:
server.cipher_list = ctx.options.ciphers_server.split(":")
@ -183,7 +185,7 @@ class TlsConfig:
if os.path.isfile(client_certs):
client_cert = client_certs
else:
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
server_name: str = server.sni or server.address[0]
p = os.path.join(client_certs, f"{server_name}.pem")
if os.path.isfile(p):
client_cert = p
@ -197,11 +199,11 @@ class TlsConfig:
ca_path=ctx.options.ssl_verify_upstream_trusted_confdir,
ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca,
client_cert=client_cert,
alpn_protos=server.alpn_offers,
alpn_protos=alpn_offers,
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
tls_start.ssl_conn.set_tlsext_host_name(server.sni)
tls_start.ssl_conn.set_tlsext_host_name(sni)
tls_start.ssl_conn.set_connect_state()
def running(self):
@ -250,8 +252,8 @@ class TlsConfig:
This function determines the Common Name (CN), Subject Alternative Names (SANs) and Organization Name
our certificate should have and then fetches a matching cert from the certstore.
"""
altnames: List[bytes] = []
organization: Optional[bytes] = None
altnames: List[str] = []
organization: Optional[str] = None
# Use upstream certificate if available.
if conn_context.server.certificate_list:
@ -266,11 +268,11 @@ class TlsConfig:
if conn_context.client.sni:
altnames.append(conn_context.client.sni)
elif conn_context.server.address:
altnames.append(conn_context.server.address[0].encode("idna"))
altnames.append(conn_context.server.address[0])
# As a last resort, add *something* so that we have a certificate to serve.
if not altnames:
altnames.append(b"mitmproxy")
altnames.append("mitmproxy")
# only keep first occurrence of each hostname
altnames = list(dict.fromkeys(altnames))

View File

@ -109,21 +109,21 @@ class Cert(serializable.Serializable):
return public_key.__class__.__name__.replace("PublicKey", "").replace("_", ""), -1
@property
def cn(self) -> Optional[bytes]: # TODO: make this return str
def cn(self) -> Optional[str]:
attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
if attrs:
return attrs[0].value.encode()
return attrs[0].value
return None
@property
def organization(self) -> Optional[bytes]: # TODO: make this return str
def organization(self) -> Optional[str]:
attrs = self._cert.subject.get_attributes_for_oid(x509.NameOID.ORGANIZATION_NAME)
if attrs:
return attrs[0].value.encode()
return attrs[0].value
return None
@property
def altnames(self) -> List[bytes]: # TODO: make this return str
def altnames(self) -> List[str]:
"""
Get all SubjectAlternativeName DNS altnames.
"""
@ -133,9 +133,9 @@ class Cert(serializable.Serializable):
return []
else:
return (
[x.encode() for x in ext.get_values_for_type(x509.DNSName)]
ext.get_values_for_type(x509.DNSName)
+
[str(x).encode() for x in ext.get_values_for_type(x509.IPAddress)]
[str(x) for x in ext.get_values_for_type(x509.IPAddress)]
)
@ -191,9 +191,9 @@ def create_ca(
def dummy_cert(
privkey: rsa.RSAPrivateKey,
cacert: x509.Certificate,
commonname: Optional[bytes],
sans: List[bytes],
organization: Optional[bytes] = None,
commonname: Optional[str],
sans: List[str],
organization: Optional[str] = None,
) -> Cert:
"""
Generates a dummy certificate.
@ -206,10 +206,6 @@ def dummy_cert(
Returns cert if operation succeeded, None if not.
"""
XX_commonname: Optional[str] = commonname.decode("idna") if commonname else None
XX_organization: Optional[str] = organization.decode() if organization else None
XX_sans: List[str] = [x.decode("ascii") for x in sans]
builder = x509.CertificateBuilder()
builder = builder.issuer_name(cacert.subject)
builder = builder.add_extension(x509.ExtendedKeyUsage([ExtendedKeyUsageOID.SERVER_AUTH]), critical=False)
@ -221,19 +217,19 @@ def dummy_cert(
subject = []
is_valid_commonname = (
XX_commonname is not None and len(XX_commonname) < 64
commonname is not None and len(commonname) < 64
)
if is_valid_commonname:
assert XX_commonname is not None
subject.append(x509.NameAttribute(NameOID.COMMON_NAME, XX_commonname))
if XX_organization is not None:
assert XX_organization is not None
subject.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, XX_organization))
assert commonname is not None
subject.append(x509.NameAttribute(NameOID.COMMON_NAME, commonname))
if organization is not None:
assert organization is not None
subject.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization))
builder = builder.subject_name(x509.Name(subject))
builder = builder.serial_number(x509.random_serial_number())
ss: List[x509.GeneralName] = []
for x in XX_sans:
for x in sans:
try:
ip = ipaddress.ip_address(x)
except ValueError:
@ -253,8 +249,8 @@ class CertStoreEntry:
chain_file: Optional[Path]
TCustomCertId = bytes # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = Tuple[Optional[bytes], Tuple[bytes, ...]] # (common_name, sans)
TCustomCertId = str # manually provided certs (e.g. mitmproxy's --certs)
TGeneratedCertId = Tuple[Optional[str], Tuple[str, ...]] # (common_name, sans)
TCertId = Union[TCustomCertId, TGeneratedCertId]
DHParams = NewType("DHParams", bytes)
@ -415,10 +411,10 @@ class CertStore:
self.add_cert(
CertStoreEntry(cert, key, path),
spec.encode("idna")
spec
)
def add_cert(self, entry: CertStoreEntry, *names: bytes) -> None:
def add_cert(self, entry: CertStoreEntry, *names: str) -> None:
"""
Adds a cert to the certstore. We register the CN in the cert plus
any SANs, and also the list of names provided as an argument.
@ -431,22 +427,22 @@ class CertStore:
self.certs[i] = entry
@staticmethod
def asterisk_forms(dn: bytes) -> List[bytes]:
def asterisk_forms(dn: str) -> List[str]:
"""
Return all asterisk forms for a domain. For example, for www.example.com this will return
[b"www.example.com", b"*.example.com", b"*.com"]. The single wildcard "*" is omitted.
"""
parts = dn.split(b".")
parts = dn.split(".")
ret = [dn]
for i in range(1, len(parts)):
ret.append(b"*." + b".".join(parts[i:]))
ret.append("*." + ".".join(parts[i:]))
return ret
def get_cert(
self,
commonname: Optional[bytes],
sans: List[bytes],
organization: Optional[bytes] = None
commonname: Optional[str],
sans: List[str],
organization: Optional[str] = None
) -> CertStoreEntry:
"""
commonname: Common name for the generated certificate. Must be a
@ -462,7 +458,7 @@ class CertStore:
potential_keys.extend(self.asterisk_forms(commonname))
for s in sans:
potential_keys.extend(self.asterisk_forms(s))
potential_keys.append(b"*")
potential_keys.append("*")
potential_keys.append((commonname, tuple(sans)))
name = next(

View File

@ -230,6 +230,25 @@ def convert_9_10(data):
return data
def convert_10_11(data):
data["version"] = 11
def conv_conn(conn):
conn["sni"] = strutils.always_str(conn["sni"], "ascii", "backslashreplace")
conn["alpn"] = strutils.always_str(conn.pop("alpn_proto_negotiated"), "utf8", "backslashreplace")
conn["alpn_offers"] = [
strutils.always_str(alpn, "utf8", "backslashreplace")
for alpn in (conn["alpn_offers"] or [])
]
conv_conn(data["client_conn"])
conv_conn(data["server_conn"])
if data["server_conn"]["via"]:
conv_conn(data["server_conn"]["via"])
return data
def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
@ -287,6 +306,7 @@ converters = {
7: convert_7_8,
8: convert_8_9,
9: convert_9_10,
10: convert_10_11,
}

View File

@ -279,7 +279,7 @@ class ClientHello:
return self._client_hello.cipher_suites.cipher_suites
@property
def sni(self) -> Optional[bytes]:
def sni(self) -> Optional[str]:
if self._client_hello.extensions:
for extension in self._client_hello.extensions.extensions:
is_valid_sni_extension = (
@ -289,15 +289,18 @@ class ClientHello:
check.is_valid_host(extension.body.server_names[0].host_name)
)
if is_valid_sni_extension:
return extension.body.server_names[0].host_name
return extension.body.server_names[0].host_name.decode("ascii")
return None
@property
def alpn_protocols(self):
def alpn_protocols(self) -> List[str]:
if self._client_hello.extensions:
for extension in self._client_hello.extensions.extensions:
if extension.type == 0x10:
return list(x.name for x in extension.body.alpn_protocols)
try:
return [x.name.decode() for x in extension.body.alpn_protocols]
except UnicodeDecodeError:
return []
return []
@property

View File

@ -56,8 +56,8 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
TLS version, with the exception of the end-entity certificate which
MUST be first.
"""
alpn: Optional[bytes] = None
alpn_offers: Sequence[bytes] = ()
alpn: Optional[str] = None
alpn_offers: Sequence[str] = ()
# we may want to add SSL_CIPHER_description here, but that's currently not exposed by cryptography
cipher: Optional[str] = None
@ -65,7 +65,7 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
cipher_list: Sequence[str] = ()
"""Ciphers accepted by the proxy server on this connection."""
tls_version: Optional[str] = None
sni: Union[bytes, Literal[True], None]
sni: Union[str, Literal[True], None]
timestamp_end: Optional[float] = None
"""Connection end timestamp"""
@ -98,7 +98,9 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
@property
def alpn_proto_negotiated(self) -> Optional[bytes]: # pragma: no cover
warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", DeprecationWarning)
return self.alpn
if self.alpn is not None:
return self.alpn.encode()
return None
class Client(Connection):
@ -110,7 +112,7 @@ class Client(Connection):
"""TCP SYN received"""
mitmcert: Optional[certs.Cert] = None
sni: Union[bytes, None] = None
sni: Union[str, None] = None
def __init__(self, peername, sockname, timestamp_start):
self.id = str(uuid.uuid4())
@ -123,7 +125,7 @@ class Client(Connection):
# This means we need to add all new fields to the old implementation.
return {
'address': self.peername,
'alpn_proto_negotiated': self.alpn,
'alpn': self.alpn,
'cipher_name': self.cipher,
'id': self.id,
'mitmcert': self.mitmcert.get_state() if self.mitmcert is not None else None,
@ -156,7 +158,7 @@ class Client(Connection):
def set_state(self, state):
self.peername = tuple(state["address"]) if state["address"] else None
self.alpn = state["alpn_proto_negotiated"]
self.alpn = state["alpn"]
self.cipher = state["cipher_name"]
self.id = state["id"]
self.sni = state["sni"]
@ -217,7 +219,7 @@ class Server(Connection):
timestamp_tcp_setup: Optional[float] = None
"""TCP ACK received"""
sni: Union[bytes, Literal[True], None] = True
sni: Union[str, Literal[True], None] = True
"""True: client SNI, False: no SNI, bytes: custom value"""
via: Optional[server_spec.ServerSpec] = None
@ -228,7 +230,7 @@ class Server(Connection):
def get_state(self):
return {
'address': self.address,
'alpn_proto_negotiated': self.alpn,
'alpn': self.alpn,
'id': self.id,
'ip_address': self.peername,
'sni': self.sni,
@ -259,7 +261,7 @@ class Server(Connection):
def set_state(self, state):
self.address = tuple(state["address"]) if state["address"] else None
self.alpn = state["alpn_proto_negotiated"]
self.alpn = state["alpn"]
self.id = state["id"]
self.peername = tuple(state["ip_address"]) if state["ip_address"] else None
self.sni = state["sni"]

View File

@ -351,7 +351,7 @@ class HttpStream(layer.Layer):
yield HttpErrorHook(self.flow)
# For HTTP/2 we only want to kill the specific stream, for HTTP/1 we want to kill the connection
# *without* sending an HTTP response (that could be achieved by the user by setting flow.response).
if self.context.client.alpn == b"h2":
if self.context.client.alpn == "h2":
yield SendHttp(ResponseProtocolError(self.stream_id, "killed"), self.context.client)
else:
if self.context.client.state & ConnectionState.CAN_WRITE:
@ -435,7 +435,7 @@ class HttpStream(layer.Layer):
stack = tunnel.LayerStack()
if self.context.server.via.scheme == "https":
http_proxy.sni = self.context.server.via.address[0].encode()
http_proxy.sni = self.context.server.via.address[0]
stack /= tls.ServerTLSLayer(self.context, http_proxy)
stack /= _upstream_proxy.HttpUpstreamProxy(self.context, http_proxy, True)
@ -532,7 +532,7 @@ class HttpLayer(layer.Layer):
self.command_sources = {}
http_conn: HttpConnection
if self.context.client.alpn == b"h2":
if self.context.client.alpn == "h2":
http_conn = Http2Server(context.fork())
else:
http_conn = Http1Server(context.fork())
@ -606,10 +606,10 @@ class HttpLayer(layer.Layer):
for connection in self.connections:
# see "tricky multiplexing edge case" in make_http_connection for an explanation
conn_is_pending_or_h2 = (
connection.alpn == b"h2"
connection.alpn == "h2"
or connection in self.waiting_for_establishment
)
h2_to_h1 = self.context.client.alpn == b"h2" and not conn_is_pending_or_h2
h2_to_h1 = self.context.client.alpn == "h2" and not conn_is_pending_or_h2
connection_suitable = (
event.connection_spec_matches(connection)
and not h2_to_h1
@ -635,7 +635,7 @@ class HttpLayer(layer.Layer):
context.server = Server(event.address)
if event.tls:
context.server.sni = event.address[0].encode()
context.server.sni = event.address[0]
if event.via:
assert event.via.scheme in ("http", "https")
@ -643,7 +643,7 @@ class HttpLayer(layer.Layer):
if event.via.scheme == "https":
http_proxy.alpn_offers = tls.HTTP_ALPNS
http_proxy.sni = event.via.address[0].encode()
http_proxy.sni = event.via.address[0]
stack /= tls.ServerTLSLayer(context, http_proxy)
send_connect = not (self.mode == HTTPMode.upstream and not event.tls)
@ -679,7 +679,7 @@ class HttpLayer(layer.Layer):
# that neither have a content-length specified nor a chunked transfer encoding.
# We can't process these two flows to the same h1 connection as they would both have
# "read until eof" semantics. The only workaround left is to open a separate connection for each flow.
if not command.err and self.context.client.alpn == b"h2" and command.connection.alpn != b"h2":
if not command.err and self.context.client.alpn == "h2" and command.connection.alpn != "h2":
for cmd in waiting[1:]:
yield from self.get_connection(cmd, reuse=False)
break
@ -695,7 +695,7 @@ class HttpClient(layer.Layer):
err = yield commands.OpenConnection(self.context.server)
if not err:
child_layer: layer.Layer
if self.context.server.alpn == b"h2":
if self.context.server.alpn == "h2":
child_layer = Http2Client(self.context)
else:
child_layer = Http1Client(self.context)

View File

@ -42,7 +42,7 @@ class ReverseProxy(DestinationKnown):
if spec.scheme not in ("http", "tcp"):
if not self.context.options.keep_host_header:
self.context.server.sni = spec.address[0].encode()
self.context.server.sni = spec.address[0]
self.child_layer = tls.ServerTLSLayer(self.context)
else:
self.child_layer = layer.NextLayer(self.context)

View File

@ -91,8 +91,8 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
return None
HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS
HTTP1_ALPNS = ("http/1.1", "http/1.0", "http/0.9")
HTTP_ALPNS = ("h2",) + HTTP1_ALPNS
# We need these classes as hooks can only have one argument at the moment.
@ -183,6 +183,8 @@ class _TLSLayer(tunnel.TunnelLayer):
err = f"OpenSSL {e!r}"
return False, err
else:
# Here we set all attributes that are only known *after* the handshake.
# Get all peer certificates.
# https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_cert_chain.html
# If called on the client side, the stack also contains the peer's certificate; if called on the server
@ -194,11 +196,8 @@ class _TLSLayer(tunnel.TunnelLayer):
all_certs.insert(0, cert)
self.conn.timestamp_tls_setup = time.time()
self.conn.sni = self.tls.get_servername()
self.conn.alpn = self.tls.get_alpn_proto_negotiated()
self.conn.alpn = self.tls.get_alpn_proto_negotiated().decode()
self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs]
if isinstance(self.conn, context.Client):
self.conn.mitmcert = certs.Cert.from_pyopenssl(self.tls.get_certificate())
self.conn.cipher = self.tls.get_cipher_name()
self.conn.tls_version = self.tls.get_protocol_version_name()
if self.debug:
@ -339,8 +338,7 @@ class ClientTLSLayer(_TLSLayer):
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
if self.conn.sni:
assert isinstance(self.conn.sni, bytes)
dest = self.conn.sni.decode("idna")
dest = self.conn.sni
else:
dest = human.format_address(self.context.server.address)
if err.startswith("Cannot parse ClientHello"):

View File

@ -427,7 +427,8 @@ if __name__ == "__main__": # pragma: no cover
tls_start.ssl_conn.set_accept_state()
else:
tls_start.ssl_conn.set_connect_state()
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni)
if tls_start.context.client.sni is not None:
tls_start.ssl_conn.set_tlsext_host_name(tls_start.context.client.sni.encode())
await SimpleConnectionHandler(reader, writer, opts, {
"next_layer": next_layer,

View File

@ -158,7 +158,7 @@ def tclient_conn() -> context.Client:
timestamp_end=946681206,
sni="address",
cipher_name="cipher",
alpn_proto_negotiated=b"http/1.1",
alpn="http/1.1",
tls_version="TLSv1.2",
tls_extensions=[(0x00, bytes.fromhex("000e00000b6578616d"))],
state=0,
@ -185,7 +185,7 @@ def tserver_conn() -> context.Server:
timestamp_end=946681205,
tls_established=True,
sni="address",
alpn_proto_negotiated=None,
alpn=None,
tls_version="TLSv1.2",
via=None,
state=0,

View File

@ -5,7 +5,6 @@ import mitmproxy.flow
from mitmproxy import http
from mitmproxy.tools.console import common, searchable
from mitmproxy.utils import human
from mitmproxy.utils import strutils
def maybe_timestamp(base, attr):
@ -40,23 +39,23 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
text.append(urwid.Text([("head", "Metadata:")]))
text.extend(common.format_keyvals(parts, indent=4))
if sc is not None and sc.ip_address:
if sc is not None and sc.peername:
text.append(urwid.Text([("head", "Server Connection:")]))
parts = [
("Address", human.format_address(sc.address)),
]
if sc.ip_address:
parts.append(("Resolved Address", human.format_address(sc.ip_address)))
if sc.peername:
parts.append(("Resolved Address", human.format_address(sc.peername)))
if resp:
parts.append(("HTTP Version", resp.http_version))
if sc.alpn_proto_negotiated:
parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn_proto_negotiated)))
if sc.alpn:
parts.append(("ALPN", sc.alpn))
text.extend(
common.format_keyvals(parts, indent=4)
)
c = sc.cert
c = sc.certificate_list[0]
if c:
text.append(urwid.Text([("head", "Server Certificate:")]))
parts = [
@ -65,39 +64,12 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
("Valid to", str(c.notafter)),
("Valid from", str(c.notbefore)),
("Serial", str(c.serial)),
(
"Subject",
urwid.BoxAdapter(
urwid.ListBox(
common.format_keyvals(
c.subject,
key_format="highlight"
)
),
len(c.subject)
)
),
(
"Issuer",
urwid.BoxAdapter(
urwid.ListBox(
common.format_keyvals(
c.issuer,
key_format="highlight"
)
),
len(c.issuer)
)
)
("Subject", urwid.Pile(common.format_keyvals(c.subject, key_format="highlight"))),
("Issuer", urwid.Pile(common.format_keyvals(c.issuer, key_format="highlight")))
]
if c.altnames:
parts.append(
(
"Alt names",
", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames)
)
)
parts.append(("Alt names", ", ".join(c.altnames)))
text.extend(
common.format_keyvals(parts, indent=4)
)
@ -106,19 +78,18 @@ def flowdetails(state, flow: mitmproxy.flow.Flow):
text.append(urwid.Text([("head", "Client Connection:")]))
parts = [
("Address", "{}:{}".format(cc.address[0], cc.address[1])),
("Address", human.format_address(cc.peername)),
]
if req:
parts.append(("HTTP Version", req.http_version))
if cc.tls_version:
parts.append(("TLS Version", cc.tls_version))
if cc.sni:
parts.append(("Server Name Indication",
strutils.bytes_to_escaped_str(strutils.always_bytes(cc.sni, "idna"))))
if cc.cipher_name:
parts.append(("Cipher Name", cc.cipher_name))
if cc.alpn_proto_negotiated:
parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn_proto_negotiated)))
parts.append(("Server Name Indication", cc.sni))
if cc.cipher:
parts.append(("Cipher Name", cc.cipher))
if cc.alpn:
parts.append(("ALPN", cc.alpn))
text.extend(
common.format_keyvals(parts, indent=4)

View File

@ -20,7 +20,6 @@ from mitmproxy import io
from mitmproxy import log
from mitmproxy import optmanager
from mitmproxy import version
from mitmproxy.utils.strutils import always_str
def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
@ -47,10 +46,9 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"timestamp_start": flow.client_conn.timestamp_start,
"timestamp_tls_setup": flow.client_conn.timestamp_tls_setup,
"timestamp_end": flow.client_conn.timestamp_end,
# ideally idna, but we don't want errors
"sni": always_str(flow.client_conn.sni, "ascii", "backslashreplace"),
"sni": flow.client_conn.sni,
"cipher_name": flow.client_conn.cipher,
"alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"),
"alpn_proto_negotiated": flow.client_conn.alpn,
"tls_version": flow.client_conn.tls_version,
}
@ -61,18 +59,14 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"ip_address": flow.server_conn.peername,
"source_address": flow.server_conn.sockname,
"tls_established": flow.server_conn.tls_established,
"alpn_proto_negotiated": always_str(flow.server_conn.alpn, "ascii", "backslashreplace"),
"sni": flow.server_conn.sni,
"alpn_proto_negotiated": flow.server_conn.alpn,
"tls_version": flow.server_conn.tls_version,
"timestamp_start": flow.server_conn.timestamp_start,
"timestamp_tcp_setup": flow.server_conn.timestamp_tcp_setup,
"timestamp_tls_setup": flow.server_conn.timestamp_tls_setup,
"timestamp_end": flow.server_conn.timestamp_end,
}
if flow.server_conn.sni is True:
f["server_conn"] = None
else:
# ideally idna, but we don't want errors
f["server_conn"] = always_str(flow.server_conn.sni, "ascii", "backslashreplace")
if flow.error:
f["error"] = flow.error.get_state()

View File

@ -7,7 +7,7 @@ MITMPROXY = "mitmproxy " + VERSION
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
# for each change in the file format.
FLOW_FORMAT_VERSION = 10
FLOW_FORMAT_VERSION = 11
def get_dev_version() -> str:

View File

@ -58,20 +58,20 @@ class TestTlsConfig:
# Edge case first: We don't have _any_ idea about the server, so we just return "mitmproxy" as subject.
entry = ta.get_cert(ctx)
assert entry.cert.cn == b"mitmproxy"
assert entry.cert.cn == "mitmproxy"
# Here we have an existing server connection...
ctx.server.address = ("server-address.example", 443)
with open(tdata.path("mitmproxy/net/data/verificationcerts/trusted-leaf.crt"), "rb") as f:
ctx.server.certificate_list = [certs.Cert.from_pem(f.read())]
entry = ta.get_cert(ctx)
assert entry.cert.cn == b"example.mitmproxy.org"
assert entry.cert.altnames == [b"example.mitmproxy.org", b"server-address.example"]
assert entry.cert.cn == "example.mitmproxy.org"
assert entry.cert.altnames == ["example.mitmproxy.org", "server-address.example"]
# And now we also incorporate SNI.
ctx.client.sni = b"sni.example"
ctx.client.sni = "sni.example"
entry = ta.get_cert(ctx)
assert entry.cert.altnames == [b"example.mitmproxy.org", b"sni.example"]
assert entry.cert.altnames == ["example.mitmproxy.org", "sni.example"]
def test_tls_clienthello(self):
# only really testing for coverage here, there's no point in mirroring the individual conditions
@ -127,7 +127,7 @@ class TestTlsConfig:
ta = tlsconfig.TlsConfig()
with taddons.context(ta) as tctx:
ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options)
ctx.client.alpn_offers = [b"h2"]
ctx.client.alpn_offers = ["h2"]
ctx.client.cipher_list = ["TLS_AES_256_GCM_SHA384", "ECDHE-RSA-AES128-SHA"]
ctx.server.address = ("example.mitmproxy.org", 443)
@ -185,8 +185,8 @@ class TestTlsConfig:
ta.tls_start(tls_start)
assert ctx.server.alpn_offers == expected
assert_alpn(True, tls.HTTP_ALPNS + (b"foo",), tls.HTTP_ALPNS + (b"foo",))
assert_alpn(False, tls.HTTP_ALPNS + (b"foo",), tls.HTTP1_ALPNS + (b"foo",))
assert_alpn(True, tls.HTTP_ALPNS + ("foo",), tls.HTTP_ALPNS + ("foo",))
assert_alpn(False, tls.HTTP_ALPNS + ("foo",), tls.HTTP1_ALPNS + ("foo",))
assert_alpn(True, [], tls.HTTP_ALPNS)
assert_alpn(False, [], tls.HTTP1_ALPNS)
ctx.client.timestamp_tls_setup = time.time()

View File

@ -29,7 +29,7 @@ def test_sslkeylogfile(tdata, monkeypatch):
Path(tdata.path("mitmproxy/net/data/verificationcerts/trusted-root.pem")),
Path(tdata.path("mitmproxy/net/data/dhparam.pem"))
)
entry = store.get_cert(b"example.com", [], None)
entry = store.get_cert("example.com", [], None)
cctx = tls.create_proxy_server_context(
min_version=tls.DEFAULT_MIN_VERSION,
@ -105,12 +105,12 @@ class TestClientHello:
)
c = tls.ClientHello(data)
assert repr(c)
assert c.sni == b'example.com'
assert c.sni == 'example.com'
assert c.cipher_suites == [
49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161,
49171, 49162, 49172, 156, 157, 47, 53, 10
]
assert c.alpn_protocols == [b'h2', b'http/1.1']
assert c.alpn_protocols == ['h2', 'http/1.1']
assert c.extensions == [
(65281, b'\x00'),
(0, b'\x00\x0e\x00\x00\x0bexample.com'),

View File

@ -44,7 +44,7 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
tctx.client.alpn = b"h2"
tctx.client.alpn = "h2"
frame_factory = FrameFactory()
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
@ -58,7 +58,7 @@ def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
def make_h2(open_connection: OpenConnection) -> None:
open_connection.connection.alpn = b"h2"
open_connection.connection.alpn = "h2"
def test_simple(tctx):

View File

@ -208,7 +208,7 @@ def h2_frames(draw):
def h2_layer(opts):
tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), opts)
tctx.client.alpn = b"h2"
tctx.client.alpn = "h2"
layer = http.HttpLayer(tctx, HTTPMode.regular)
for _ in layer.handle_event(Start()):

View File

@ -22,7 +22,7 @@ def event_types(events):
def h2_client(tctx: Context) -> Tuple[h2.connection.H2Connection, Playbook]:
tctx.client.alpn = b"h2"
tctx.client.alpn = "h2"
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
conn = h2.connection.H2Connection()

View File

@ -202,7 +202,7 @@ def test_reverse_proxy_tcp_over_tls(tctx: Context, monkeypatch, patch, connectio
>> reply_tls_start()
<< SendData(tctx.server, data)
)
assert tls.parse_client_hello(data()).sni == b"localhost"
assert tls.parse_client_hello(data()).sni == "localhost"
@pytest.mark.parametrize("connection_strategy", ["eager", "lazy"])

View File

@ -76,7 +76,7 @@ def test_get_client_hello():
def test_parse_client_hello():
assert tls.parse_client_hello(client_hello_with_extensions).sni == b"example.com"
assert tls.parse_client_hello(client_hello_with_extensions).sni == "example.com"
assert tls.parse_client_hello(client_hello_with_extensions[:50]) is None
with pytest.raises(ValueError):
tls.parse_client_hello(client_hello_with_extensions[:183] + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00')
@ -188,7 +188,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
tls_start.ssl_conn = SSL.Connection(ssl_context)
tls_start.ssl_conn.set_connect_state()
# Set SNI
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni)
tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode("ascii"))
# Manually enable hostname verification.
# Recent OpenSSL versions provide slightly nicer ways to do this, but they are not exposed in
@ -202,7 +202,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT
)
SSL._openssl_assert(
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni, 0) == 1
SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode("ascii"), 0) == 1
)
return tutils.reply(*args, side_effect=make_conn, **kwargs)
@ -227,7 +227,7 @@ class TestServerTLS:
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.state = ConnectionState.OPEN
tctx.server.address = ("example.mitmproxy.org", 443)
tctx.server.sni = b"example.mitmproxy.org"
tctx.server.sni = "example.mitmproxy.org"
tssl = SSLTest(server_side=True)
@ -280,7 +280,7 @@ class TestServerTLS:
"""If the certificate is not trusted, we should fail."""
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.address = ("wrong.host.mitmproxy.org", 443)
tctx.server.sni = b"wrong.host.mitmproxy.org"
tctx.server.sni = "wrong.host.mitmproxy.org"
tssl = SSLTest(server_side=True)
@ -316,7 +316,7 @@ class TestServerTLS:
def test_remote_speaks_no_tls(self, tctx):
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
tctx.server.state = ConnectionState.OPEN
tctx.server.sni = b"example.mitmproxy.org"
tctx.server.sni = "example.mitmproxy.org"
# send ClientHello, receive random garbage back
data = tutils.Placeholder(bytes)
@ -345,7 +345,7 @@ def make_client_tls_layer(
# Add some server config, this is needed anyways.
tctx.server.address = ("example.mitmproxy.org", 443)
tctx.server.sni = b"example.mitmproxy.org"
tctx.server.sni = "example.mitmproxy.org"
tssl_client = SSLTest(**kwargs)
# Start handshake.
@ -446,8 +446,8 @@ class TestClientTLS:
assert tctx.client.tls_established
assert tctx.server.tls_established
assert tctx.server.sni == tctx.client.sni
assert tctx.client.alpn == b"quux"
assert tctx.server.alpn == b"quux"
assert tctx.client.alpn == "quux"
assert tctx.server.alpn == "quux"
_test_echo(playbook, tssl_server, tctx.server)
_test_echo(playbook, tssl_client, tctx.client)

View File

@ -46,10 +46,10 @@ class TestCertStore:
def test_create_explicit(self, tmpdir):
ca = certs.CertStore.from_store(str(tmpdir), "test", 2048)
assert ca.get_cert(b"foo", [])
assert ca.get_cert("foo", [])
ca2 = certs.CertStore.from_store(str(tmpdir), "test", 2048)
assert ca2.get_cert(b"foo", [])
assert ca2.get_cert("foo", [])
assert ca.default_ca.serial == ca2.default_ca.serial
@ -57,51 +57,51 @@ class TestCertStore:
assert tstore.get_cert(None, []).cert.cn is None
def test_sans(self, tstore):
c1 = tstore.get_cert(b"foo.com", [b"*.bar.com"])
tstore.get_cert(b"foo.bar.com", [])
c1 = tstore.get_cert("foo.com", ["*.bar.com"])
tstore.get_cert("foo.bar.com", [])
# assert c1 == c2
c3 = tstore.get_cert(b"bar.com", [])
c3 = tstore.get_cert("bar.com", [])
assert not c1 == c3
def test_sans_change(self, tstore):
tstore.get_cert(b"foo.com", [b"*.bar.com"])
entry = tstore.get_cert(b"foo.bar.com", [b"*.baz.com"])
assert b"*.baz.com" in entry.cert.altnames
tstore.get_cert("foo.com", ["*.bar.com"])
entry = tstore.get_cert("foo.bar.com", ["*.baz.com"])
assert "*.baz.com" in entry.cert.altnames
def test_expire(self, tstore):
tstore.STORE_CAP = 3
tstore.get_cert(b"one.com", [])
tstore.get_cert(b"two.com", [])
tstore.get_cert(b"three.com", [])
tstore.get_cert("one.com", [])
tstore.get_cert("two.com", [])
tstore.get_cert("three.com", [])
assert (b"one.com", ()) in tstore.certs
assert (b"two.com", ()) in tstore.certs
assert (b"three.com", ()) in tstore.certs
assert ("one.com", ()) in tstore.certs
assert ("two.com", ()) in tstore.certs
assert ("three.com", ()) in tstore.certs
tstore.get_cert(b"one.com", [])
tstore.get_cert("one.com", [])
assert (b"one.com", ()) in tstore.certs
assert (b"two.com", ()) in tstore.certs
assert (b"three.com", ()) in tstore.certs
assert ("one.com", ()) in tstore.certs
assert ("two.com", ()) in tstore.certs
assert ("three.com", ()) in tstore.certs
tstore.get_cert(b"four.com", [])
tstore.get_cert("four.com", [])
assert (b"one.com", ()) not in tstore.certs
assert (b"two.com", ()) in tstore.certs
assert (b"three.com", ()) in tstore.certs
assert (b"four.com", ()) in tstore.certs
assert ("one.com", ()) not in tstore.certs
assert ("two.com", ()) in tstore.certs
assert ("three.com", ()) in tstore.certs
assert ("four.com", ()) in tstore.certs
def test_overrides(self, tmp_path):
ca1 = certs.CertStore.from_store(tmp_path / "ca1", "test", 2048)
ca2 = certs.CertStore.from_store(tmp_path / "ca2", "test", 2048)
assert not ca1.default_ca.serial == ca2.default_ca.serial
dc = ca2.get_cert(b"foo.com", [b"sans.example.com"])
dc = ca2.get_cert("foo.com", ["sans.example.com"])
dcp = tmp_path / "dc"
dcp.write_bytes(dc.cert.to_pem())
ca1.add_cert_file("foo.com", dcp)
ret = ca1.get_cert(b"foo.com", [])
ret = ca1.get_cert("foo.com", [])
assert ret.cert.serial == dc.cert.serial
def test_create_dhparams(self, tmp_path):
@ -124,13 +124,13 @@ class TestDummyCert:
r = certs.dummy_cert(
tstore.default_privatekey,
tstore.default_ca._cert,
b"foo.com",
[b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"],
b"Foo Ltd."
"foo.com",
["one.com", "two.com", "*.three.com", "127.0.0.1"],
"Foo Ltd."
)
assert r.cn == b"foo.com"
assert r.altnames == [b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"]
assert r.organization == b"Foo Ltd."
assert r.cn == "foo.com"
assert r.altnames == ["one.com", "two.com", "*.three.com", "127.0.0.1"]
assert r.organization == "Foo Ltd."
r = certs.dummy_cert(
tstore.default_privatekey,
@ -150,14 +150,14 @@ class TestCert:
with open(tdata.path("mitmproxy/net/data/text_cert"), "rb") as f:
d = f.read()
c1 = certs.Cert.from_pem(d)
assert c1.cn == b"google.com"
assert c1.cn == "google.com"
assert len(c1.altnames) == 436
assert c1.organization == b"Google Inc"
assert c1.organization == "Google Inc"
with open(tdata.path("mitmproxy/net/data/text_cert_2"), "rb") as f:
d = f.read()
c2 = certs.Cert.from_pem(d)
assert c2.cn == b"www.inode.co.nz"
assert c2.cn == "www.inode.co.nz"
assert len(c2.altnames) == 2
assert c2.fingerprint()
assert c2.notbefore
@ -197,4 +197,4 @@ class TestCert:
def test_from_store_with_passphrase(self, tdata, tstore):
tstore.add_cert_file("*", Path(tdata.path("mitmproxy/data/mitmproxy.pem")), b"password")
assert tstore.get_cert(b"foo", [])
assert tstore.get_cert("foo", [])

View File

@ -25,12 +25,10 @@ def test_format_keyvals():
("ee", None),
]
)
wrapped = urwid.BoxAdapter(
urwid.ListBox(
urwid.SimpleFocusListWalker(
common.format_keyvals([("foo", "bar")])
)
), 1
wrapped = urwid.Pile(
urwid.SimpleFocusListWalker(
common.format_keyvals([("foo", "bar")])
)
)
assert wrapped.render((30,))
assert common.format_keyvals(