From 9f39e2f387463f3cfcd044b7c7135865e06c506d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 4 Sep 2021 16:03:06 +0200 Subject: [PATCH] tests++ --- docs/scripts/api-events.py | 1 + examples/contrib/tls_passthrough.py | 2 +- mitmproxy/proxy/layers/tls.py | 2 +- test/mitmproxy/proxy/layers/test_tls.py | 43 ++++++++++++++++++++++--- 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/docs/scripts/api-events.py b/docs/scripts/api-events.py index 80d91dae9..462c02531 100644 --- a/docs/scripts/api-events.py +++ b/docs/scripts/api-events.py @@ -124,6 +124,7 @@ with outfile.open("w") as f, contextlib.redirect_stdout(f): tls.TlsClienthelloHook, tls.TlsStartClientHook, tls.TlsStartServerHook, + tls.TlsHandshakeHook, ] ) diff --git a/examples/contrib/tls_passthrough.py b/examples/contrib/tls_passthrough.py index d248bc36b..8652651f2 100644 --- a/examples/contrib/tls_passthrough.py +++ b/examples/contrib/tls_passthrough.py @@ -97,7 +97,7 @@ class MaybeTls: def tls_handshake(self, data: tls.TlsHookData): if isinstance(data.conn, connection.Server): - return + return # we are only interested in failing client connections here. server_address = data.context.server.peername if data.conn.error is None: ctx.log(f"TLS handshake successful: {human.format_address(server_address)}") diff --git a/mitmproxy/proxy/layers/tls.py b/mitmproxy/proxy/layers/tls.py index 86289bfd8..34c7cea7e 100644 --- a/mitmproxy/proxy/layers/tls.py +++ b/mitmproxy/proxy/layers/tls.py @@ -424,7 +424,7 @@ class ClientTLSLayer(_TLSLayer): # we've figured out that we don't want to intercept this connection, so we assign fake connection objects # to all TLS layers. This makes the real connection contents just go through. self.conn = self.tunnel_connection = connection.Client(("ignore-conn", 0), ("ignore-conn", 0), time.time()) - parent_layer = self.context.layers[-2] + parent_layer = self.context.layers[self.context.layers.index(self) - 1] if isinstance(parent_layer, ServerTLSLayer): parent_layer.conn = parent_layer.tunnel_connection = connection.Server(None) self.child_layer = tcp.TCPLayer(self.context, ignore=True) diff --git a/test/mitmproxy/proxy/layers/test_tls.py b/test/mitmproxy/proxy/layers/test_tls.py index 3a8e1a160..4c4aad7d1 100644 --- a/test/mitmproxy/proxy/layers/test_tls.py +++ b/test/mitmproxy/proxy/layers/test_tls.py @@ -1,4 +1,5 @@ import ssl +import time import typing import pytest @@ -459,13 +460,13 @@ class TestClientTLS: << commands.SendData(other_server, b"plaintext") ) - @pytest.mark.parametrize("eager", ["eager", ""]) - def test_server_required(self, tctx, eager): + @pytest.mark.parametrize("server_state", ["open", "closed"]) + def test_server_required(self, tctx, server_state): """ Test the scenario where a server connection is required (for example, because of an unknown ALPN) to establish TLS with the client. """ - if eager: + if server_state == "open": tctx.server.state = ConnectionState.OPEN tssl_server = SSLTest(server_side=True, alpn=["quux"]) playbook, client_layer, tssl_client = make_client_tls_layer(tctx, alpn=["quux"]) @@ -482,7 +483,7 @@ class TestClientTLS: << tls.TlsClienthelloHook(tutils.Placeholder()) >> tutils.reply(side_effect=require_server_conn) ) - if not eager: + if server_state == "closed": ( playbook << commands.OpenConnection(tctx.server) @@ -532,6 +533,40 @@ class TestClientTLS: _test_echo(playbook, tssl_server, tctx.server) _test_echo(playbook, tssl_client, tctx.client) + @pytest.mark.parametrize("server_state", ["open", "closed"]) + def test_passthrough_from_clienthello(self, tctx, server_state): + """ + Test the scenario where the connection is moved to passthrough mode in the tls_clienthello hook. + """ + if server_state == "open": + tctx.server.timestamp_start = time.time() + tctx.server.state = ConnectionState.OPEN + + playbook, client_layer, tssl_client = make_client_tls_layer(tctx, alpn=["quux"]) + + def make_passthrough(client_hello: tls.ClientHelloData) -> None: + client_hello.ignore_connection = True + + client_hello = tssl_client.bio_read() + ( + playbook + >> events.DataReceived(tctx.client, client_hello) + << tls.TlsClienthelloHook(tutils.Placeholder()) + >> tutils.reply(side_effect=make_passthrough) + ) + if server_state == "closed": + ( + playbook + << commands.OpenConnection(tctx.server) + >> tutils.reply(None) + ) + assert ( + playbook + << commands.SendData(tctx.server, client_hello) # passed through unmodified + >> events.DataReceived(tctx.server, b"ServerHello") # and the same for the serverhello. + << commands.SendData(tctx.client, b"ServerHello") + ) + def test_cannot_parse_clienthello(self, tctx: context.Context): """Test the scenario where we cannot parse the ClientHello""" playbook, client_layer, tssl_client = make_client_tls_layer(tctx)