From dfba6e81a655a99edd16803bb3bc7731da8e81de Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 30 Dec 2020 22:52:07 +0100 Subject: [PATCH] alpn: str -> bytes --- mitmproxy/addons/tlsconfig.py | 16 +++++++--------- mitmproxy/io/compat.py | 8 +++----- mitmproxy/net/tls.py | 12 +++++------- mitmproxy/proxy/context.py | 8 +++----- mitmproxy/proxy/layers/http/__init__.py | 12 ++++++------ mitmproxy/proxy/layers/tls.py | 6 +++--- mitmproxy/test/tflow.py | 2 +- mitmproxy/tools/console/flowdetailview.py | 8 ++++---- mitmproxy/tools/web/app.py | 5 +++-- test/mitmproxy/addons/test_tlsconfig.py | 6 +++--- test/mitmproxy/net/test_tls.py | 2 +- test/mitmproxy/proxy/layers/http/test_http2.py | 4 ++-- .../proxy/layers/http/test_http_fuzz.py | 2 +- .../layers/http/test_http_version_interop.py | 2 +- test/mitmproxy/proxy/layers/test_tls.py | 8 ++++---- 15 files changed, 47 insertions(+), 54 deletions(-) diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index dea7b48a2..f1dc90c07 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -8,7 +8,6 @@ from mitmproxy.net import tls as net_tls from mitmproxy.options import CONF_BASENAME from mitmproxy.proxy import context from mitmproxy.proxy.layers import tls -from mitmproxy.utils.strutils import always_bytes # We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default. # https://ssl-config.mozilla.org/#config=old @@ -36,7 +35,7 @@ def alpn_select_callback(conn: SSL.Connection, options: List[bytes]) -> Any: return server_alpn http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS for alpn in options: # client sends in order of preference, so we are nice and respect that. - if alpn.decode(errors="replace") in http_alpns: + if alpn in http_alpns: return alpn else: return SSL.NO_OVERLAPPING_PROTOCOLS @@ -138,7 +137,7 @@ class TlsConfig: ) tls_start.ssl_conn = SSL.Connection(ssl_ctx) tls_start.ssl_conn.set_app_data(AppData( - server_alpn=always_bytes(server.alpn, "utf8", "replace"), + server_alpn=server.alpn, http2=ctx.options.http2, )) tls_start.ssl_conn.set_accept_state() @@ -155,14 +154,13 @@ class TlsConfig: if server.sni is True: server.sni = client.sni or server.address[0] - sni: Optional[bytes] = server.sni.encode("ascii") if server.sni else None if not server.alpn_offers: if client.alpn_offers: if ctx.options.http2: server.alpn_offers = tuple(client.alpn_offers) else: - server.alpn_offers = tuple(x for x in client.alpn_offers if x != "h2") + server.alpn_offers = tuple(x for x in client.alpn_offers if x != b"h2") elif client.tls_established: # We would perfectly support HTTP/1 -> HTTP/2, but we want to keep things on the same protocol version. # There are some edge cases where we want to mirror the regular server's behavior accurately, @@ -172,7 +170,6 @@ class TlsConfig: server.alpn_offers = tls.HTTP_ALPNS else: server.alpn_offers = tls.HTTP1_ALPNS - alpn_offers: List[bytes] = [alpn.encode() for alpn in server.alpn_offers] if not server.cipher_list and ctx.options.ciphers_server: server.cipher_list = ctx.options.ciphers_server.split(":") @@ -195,15 +192,16 @@ class TlsConfig: max_version=net_tls.Version[ctx.options.tls_version_client_max], cipher_list=cipher_list, verify=verify, - sni=sni, + sni=server.sni, ca_path=ctx.options.ssl_verify_upstream_trusted_confdir, ca_pemfile=ctx.options.ssl_verify_upstream_trusted_ca, client_cert=client_cert, - alpn_protos=alpn_offers, + alpn_protos=server.alpn_offers, ) tls_start.ssl_conn = SSL.Connection(ssl_ctx) - tls_start.ssl_conn.set_tlsext_host_name(sni) + if server.sni: + tls_start.ssl_conn.set_tlsext_host_name(server.sni.encode()) tls_start.ssl_conn.set_connect_state() def running(self): diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py index c109f2f25..2f098758f 100644 --- a/mitmproxy/io/compat.py +++ b/mitmproxy/io/compat.py @@ -235,11 +235,9 @@ def convert_10_11(data): def conv_conn(conn): conn["sni"] = strutils.always_str(conn["sni"], "ascii", "backslashreplace") - conn["alpn"] = strutils.always_str(conn.pop("alpn_proto_negotiated"), "utf8", "backslashreplace") - conn["alpn_offers"] = [ - strutils.always_str(alpn, "utf8", "backslashreplace") - for alpn in (conn["alpn_offers"] or []) - ] + conn["alpn"] = conn.pop("alpn_proto_negotiated") + conn["alpn_offers"] = conn["alpn_offers"] or [] + conn["cipher_list"] = conn["cipher_list"] or [] conv_conn(data["client_conn"]) conv_conn(data["server_conn"]) diff --git a/mitmproxy/net/tls.py b/mitmproxy/net/tls.py index 59c0b6edc..5af4aaa26 100644 --- a/mitmproxy/net/tls.py +++ b/mitmproxy/net/tls.py @@ -130,7 +130,7 @@ def create_proxy_server_context( max_version: Version, cipher_list: Optional[Iterable[str]], verify: Verify, - sni: Optional[bytes], + sni: Optional[str], ca_path: Optional[str], ca_pemfile: Optional[str], client_cert: Optional[str], @@ -148,6 +148,7 @@ def create_proxy_server_context( context.set_verify(verify.value, None) if sni is not None: + assert isinstance(sni, str) # Manually enable hostname verification on the context object. # https://wiki.openssl.org/index.php/Hostname_validation param = SSL._lib.SSL_CTX_get0_param(context._context) @@ -158,7 +159,7 @@ def create_proxy_server_context( SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT ) SSL._openssl_assert( - SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni, 0) == 1 + SSL._lib.X509_VERIFY_PARAM_set1_host(param, sni.encode(), 0) == 1 ) if ca_path is None and ca_pemfile is None: @@ -293,14 +294,11 @@ class ClientHello: return None @property - def alpn_protocols(self) -> List[str]: + def alpn_protocols(self) -> List[bytes]: if self._client_hello.extensions: for extension in self._client_hello.extensions.extensions: if extension.type == 0x10: - try: - return [x.name.decode() for x in extension.body.alpn_protocols] - except UnicodeDecodeError: - return [] + return list(x.name for x in extension.body.alpn_protocols) return [] @property diff --git a/mitmproxy/proxy/context.py b/mitmproxy/proxy/context.py index cdce70231..785830b4b 100644 --- a/mitmproxy/proxy/context.py +++ b/mitmproxy/proxy/context.py @@ -56,8 +56,8 @@ class Connection(serializable.Serializable, metaclass=ABCMeta): TLS version, with the exception of the end-entity certificate which MUST be first. """ - alpn: Optional[str] = None - alpn_offers: Sequence[str] = () + alpn: Optional[bytes] = None + alpn_offers: Sequence[bytes] = () # we may want to add SSL_CIPHER_description here, but that's currently not exposed by cryptography cipher: Optional[str] = None @@ -98,9 +98,7 @@ class Connection(serializable.Serializable, metaclass=ABCMeta): @property def alpn_proto_negotiated(self) -> Optional[bytes]: # pragma: no cover warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", DeprecationWarning) - if self.alpn is not None: - return self.alpn.encode() - return None + return self.alpn class Client(Connection): diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 47f96ea10..dc9d501cc 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -351,7 +351,7 @@ class HttpStream(layer.Layer): yield HttpErrorHook(self.flow) # For HTTP/2 we only want to kill the specific stream, for HTTP/1 we want to kill the connection # *without* sending an HTTP response (that could be achieved by the user by setting flow.response). - if self.context.client.alpn == "h2": + if self.context.client.alpn == b"h2": yield SendHttp(ResponseProtocolError(self.stream_id, "killed"), self.context.client) else: if self.context.client.state & ConnectionState.CAN_WRITE: @@ -532,7 +532,7 @@ class HttpLayer(layer.Layer): self.command_sources = {} http_conn: HttpConnection - if self.context.client.alpn == "h2": + if self.context.client.alpn == b"h2": http_conn = Http2Server(context.fork()) else: http_conn = Http1Server(context.fork()) @@ -606,10 +606,10 @@ class HttpLayer(layer.Layer): for connection in self.connections: # see "tricky multiplexing edge case" in make_http_connection for an explanation conn_is_pending_or_h2 = ( - connection.alpn == "h2" + connection.alpn == b"h2" or connection in self.waiting_for_establishment ) - h2_to_h1 = self.context.client.alpn == "h2" and not conn_is_pending_or_h2 + h2_to_h1 = self.context.client.alpn == b"h2" and not conn_is_pending_or_h2 connection_suitable = ( event.connection_spec_matches(connection) and not h2_to_h1 @@ -679,7 +679,7 @@ class HttpLayer(layer.Layer): # that neither have a content-length specified nor a chunked transfer encoding. # We can't process these two flows to the same h1 connection as they would both have # "read until eof" semantics. The only workaround left is to open a separate connection for each flow. - if not command.err and self.context.client.alpn == "h2" and command.connection.alpn != "h2": + if not command.err and self.context.client.alpn == b"h2" and command.connection.alpn != b"h2": for cmd in waiting[1:]: yield from self.get_connection(cmd, reuse=False) break @@ -695,7 +695,7 @@ class HttpClient(layer.Layer): err = yield commands.OpenConnection(self.context.server) if not err: child_layer: layer.Layer - if self.context.server.alpn == "h2": + if self.context.server.alpn == b"h2": child_layer = Http2Client(self.context) else: child_layer = Http1Client(self.context) diff --git a/mitmproxy/proxy/layers/tls.py b/mitmproxy/proxy/layers/tls.py index 15bd29622..1b2e8cbc6 100644 --- a/mitmproxy/proxy/layers/tls.py +++ b/mitmproxy/proxy/layers/tls.py @@ -91,8 +91,8 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]: return None -HTTP1_ALPNS = ("http/1.1", "http/1.0", "http/0.9") -HTTP_ALPNS = ("h2",) + HTTP1_ALPNS +HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9") +HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS # We need these classes as hooks can only have one argument at the moment. @@ -196,7 +196,7 @@ class _TLSLayer(tunnel.TunnelLayer): all_certs.insert(0, cert) self.conn.timestamp_tls_setup = time.time() - self.conn.alpn = self.tls.get_alpn_proto_negotiated().decode() + self.conn.alpn = self.tls.get_alpn_proto_negotiated() self.conn.certificate_list = [certs.Cert.from_pyopenssl(x) for x in all_certs] self.conn.cipher = self.tls.get_cipher_name() self.conn.tls_version = self.tls.get_protocol_version_name() diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index 366bb3e04..0a0755ad1 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -158,7 +158,7 @@ def tclient_conn() -> context.Client: timestamp_end=946681206, sni="address", cipher_name="cipher", - alpn="http/1.1", + alpn=b"http/1.1", tls_version="TLSv1.2", tls_extensions=[(0x00, bytes.fromhex("000e00000b6578616d"))], state=0, diff --git a/mitmproxy/tools/console/flowdetailview.py b/mitmproxy/tools/console/flowdetailview.py index 5ec419eea..aa3d449c8 100644 --- a/mitmproxy/tools/console/flowdetailview.py +++ b/mitmproxy/tools/console/flowdetailview.py @@ -4,7 +4,7 @@ import urwid import mitmproxy.flow from mitmproxy import http from mitmproxy.tools.console import common, searchable -from mitmproxy.utils import human +from mitmproxy.utils import human, strutils def maybe_timestamp(base, attr): @@ -49,7 +49,7 @@ def flowdetails(state, flow: mitmproxy.flow.Flow): if resp: parts.append(("HTTP Version", resp.http_version)) if sc.alpn: - parts.append(("ALPN", sc.alpn)) + parts.append(("ALPN", strutils.bytes_to_escaped_str(sc.alpn))) text.extend( common.format_keyvals(parts, indent=4) @@ -69,7 +69,7 @@ def flowdetails(state, flow: mitmproxy.flow.Flow): ] if c.altnames: - parts.append(("Alt names", ", ".join(c.altnames))) + parts.append(("Alt names", ", ".join(strutils.bytes_to_escaped_str(x) for x in c.altnames))) text.extend( common.format_keyvals(parts, indent=4) ) @@ -89,7 +89,7 @@ def flowdetails(state, flow: mitmproxy.flow.Flow): if cc.cipher: parts.append(("Cipher Name", cc.cipher)) if cc.alpn: - parts.append(("ALPN", cc.alpn)) + parts.append(("ALPN", strutils.bytes_to_escaped_str(cc.alpn))) text.extend( common.format_keyvals(parts, indent=4) diff --git a/mitmproxy/tools/web/app.py b/mitmproxy/tools/web/app.py index 2150a3234..c24e456c8 100644 --- a/mitmproxy/tools/web/app.py +++ b/mitmproxy/tools/web/app.py @@ -20,6 +20,7 @@ from mitmproxy import io from mitmproxy import log from mitmproxy import optmanager from mitmproxy import version +from mitmproxy.utils.strutils import always_str def flow_to_json(flow: mitmproxy.flow.Flow) -> dict: @@ -48,7 +49,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict: "timestamp_end": flow.client_conn.timestamp_end, "sni": flow.client_conn.sni, "cipher_name": flow.client_conn.cipher, - "alpn_proto_negotiated": flow.client_conn.alpn, + "alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"), "tls_version": flow.client_conn.tls_version, } @@ -60,7 +61,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict: "source_address": flow.server_conn.sockname, "tls_established": flow.server_conn.tls_established, "sni": flow.server_conn.sni, - "alpn_proto_negotiated": flow.server_conn.alpn, + "alpn_proto_negotiated": always_str(flow.client_conn.alpn, "ascii", "backslashreplace"), "tls_version": flow.server_conn.tls_version, "timestamp_start": flow.server_conn.timestamp_start, "timestamp_tcp_setup": flow.server_conn.timestamp_tcp_setup, diff --git a/test/mitmproxy/addons/test_tlsconfig.py b/test/mitmproxy/addons/test_tlsconfig.py index 52bd2e21a..26d0a491b 100644 --- a/test/mitmproxy/addons/test_tlsconfig.py +++ b/test/mitmproxy/addons/test_tlsconfig.py @@ -127,7 +127,7 @@ class TestTlsConfig: ta = tlsconfig.TlsConfig() with taddons.context(ta) as tctx: ctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), tctx.options) - ctx.client.alpn_offers = ["h2"] + ctx.client.alpn_offers = [b"h2"] ctx.client.cipher_list = ["TLS_AES_256_GCM_SHA384", "ECDHE-RSA-AES128-SHA"] ctx.server.address = ("example.mitmproxy.org", 443) @@ -185,8 +185,8 @@ class TestTlsConfig: ta.tls_start(tls_start) assert ctx.server.alpn_offers == expected - assert_alpn(True, tls.HTTP_ALPNS + ("foo",), tls.HTTP_ALPNS + ("foo",)) - assert_alpn(False, tls.HTTP_ALPNS + ("foo",), tls.HTTP1_ALPNS + ("foo",)) + assert_alpn(True, tls.HTTP_ALPNS + (b"foo",), tls.HTTP_ALPNS + (b"foo",)) + assert_alpn(False, tls.HTTP_ALPNS + (b"foo",), tls.HTTP1_ALPNS + (b"foo",)) assert_alpn(True, [], tls.HTTP_ALPNS) assert_alpn(False, [], tls.HTTP1_ALPNS) ctx.client.timestamp_tls_setup = time.time() diff --git a/test/mitmproxy/net/test_tls.py b/test/mitmproxy/net/test_tls.py index d9f794ef2..8833743db 100644 --- a/test/mitmproxy/net/test_tls.py +++ b/test/mitmproxy/net/test_tls.py @@ -110,7 +110,7 @@ class TestClientHello: 49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161, 49171, 49162, 49172, 156, 157, 47, 53, 10 ] - assert c.alpn_protocols == ['h2', 'http/1.1'] + assert c.alpn_protocols == [b'h2', b'http/1.1'] assert c.extensions == [ (65281, b'\x00'), (0, b'\x00\x0e\x00\x00\x0bexample.com'), diff --git a/test/mitmproxy/proxy/layers/http/test_http2.py b/test/mitmproxy/proxy/layers/http/test_http2.py index afa1e530b..9334cad76 100644 --- a/test/mitmproxy/proxy/layers/http/test_http2.py +++ b/test/mitmproxy/proxy/layers/http/test_http2.py @@ -44,7 +44,7 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]: def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]: - tctx.client.alpn = "h2" + tctx.client.alpn = b"h2" frame_factory = FrameFactory() playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) @@ -58,7 +58,7 @@ def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]: def make_h2(open_connection: OpenConnection) -> None: - open_connection.connection.alpn = "h2" + open_connection.connection.alpn = b"h2" def test_simple(tctx): diff --git a/test/mitmproxy/proxy/layers/http/test_http_fuzz.py b/test/mitmproxy/proxy/layers/http/test_http_fuzz.py index e4d836b27..93d4ffd26 100644 --- a/test/mitmproxy/proxy/layers/http/test_http_fuzz.py +++ b/test/mitmproxy/proxy/layers/http/test_http_fuzz.py @@ -208,7 +208,7 @@ def h2_frames(draw): def h2_layer(opts): tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), opts) - tctx.client.alpn = "h2" + tctx.client.alpn = b"h2" layer = http.HttpLayer(tctx, HTTPMode.regular) for _ in layer.handle_event(Start()): diff --git a/test/mitmproxy/proxy/layers/http/test_http_version_interop.py b/test/mitmproxy/proxy/layers/http/test_http_version_interop.py index cbca27a40..945029677 100644 --- a/test/mitmproxy/proxy/layers/http/test_http_version_interop.py +++ b/test/mitmproxy/proxy/layers/http/test_http_version_interop.py @@ -22,7 +22,7 @@ def event_types(events): def h2_client(tctx: Context) -> Tuple[h2.connection.H2Connection, Playbook]: - tctx.client.alpn = "h2" + tctx.client.alpn = b"h2" playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) conn = h2.connection.H2Connection() diff --git a/test/mitmproxy/proxy/layers/test_tls.py b/test/mitmproxy/proxy/layers/test_tls.py index 84cf66e94..925896a09 100644 --- a/test/mitmproxy/proxy/layers/test_tls.py +++ b/test/mitmproxy/proxy/layers/test_tls.py @@ -188,7 +188,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut tls_start.ssl_conn = SSL.Connection(ssl_context) tls_start.ssl_conn.set_connect_state() # Set SNI - tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode("ascii")) + tls_start.ssl_conn.set_tlsext_host_name(tls_start.conn.sni.encode()) # Manually enable hostname verification. # Recent OpenSSL versions provide slightly nicer ways to do this, but they are not exposed in @@ -202,7 +202,7 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut SSL._lib.X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS | SSL._lib.X509_CHECK_FLAG_NEVER_CHECK_SUBJECT ) SSL._openssl_assert( - SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode("ascii"), 0) == 1 + SSL._lib.X509_VERIFY_PARAM_set1_host(param, tls_start.conn.sni.encode(), 0) == 1 ) return tutils.reply(*args, side_effect=make_conn, **kwargs) @@ -446,8 +446,8 @@ class TestClientTLS: assert tctx.client.tls_established assert tctx.server.tls_established assert tctx.server.sni == tctx.client.sni - assert tctx.client.alpn == "quux" - assert tctx.server.alpn == "quux" + assert tctx.client.alpn == b"quux" + assert tctx.server.alpn == b"quux" _test_echo(playbook, tssl_server, tctx.server) _test_echo(playbook, tssl_client, tctx.client)