From 1c80dfe17f07855e3257d699d9242d93c6a90388 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 11 Nov 2019 18:32:01 +0100 Subject: [PATCH] [sans-io] tls layer++ --- mitmproxy/addons/tlsconfig.py | 138 ++++++++++ mitmproxy/proxy2/context.py | 4 + mitmproxy/proxy2/layer.py | 7 +- mitmproxy/proxy2/layers/http/http.py | 5 +- mitmproxy/proxy2/layers/tls.py | 105 +++----- test/mitmproxy/proxy2/layers/test_tcp.py | 6 +- test/mitmproxy/proxy2/layers/test_tls.py | 311 ++++++++++++----------- test/mitmproxy/proxy2/test_layer.py | 10 +- test/mitmproxy/proxy2/test_tutils.py | 4 +- test/mitmproxy/proxy2/tutils.py | 35 +-- 10 files changed, 371 insertions(+), 254 deletions(-) create mode 100644 mitmproxy/addons/tlsconfig.py diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py new file mode 100644 index 000000000..b0625d73c --- /dev/null +++ b/mitmproxy/addons/tlsconfig.py @@ -0,0 +1,138 @@ +import os +from typing import Optional, Tuple + +from OpenSSL import SSL, crypto + +from mitmproxy import certs, ctx, exceptions +from mitmproxy.net import tls as net_tls +from mitmproxy.options import CONF_BASENAME +from mitmproxy.proxy.protocol.tls import CIPHER_ID_NAME_MAP, DEFAULT_CLIENT_CIPHERS +from mitmproxy.proxy2 import context +from mitmproxy.proxy2.layers import tls + + +def alpn_select_callback(conn: SSL.Connection, options): + server_alpn = conn.get_app_data()["server_alpn"] + if server_alpn and server_alpn in options: + return server_alpn + for alpn in tls.HTTP_ALPNS: + if alpn in options: + return alpn + else: + # FIXME: pyOpenSSL requires that an ALPN is negotiated, we can't return SSL_TLSEXT_ERR_NOACK. + return options[0] + + +class TlsConfig: + certstore: certs.CertStore + + # TODO: We should re-use SSL.Context options here, if only for TLS session resumption. + # This may require patches to pyOpenSSL, as some functionality is only exposed on contexts. + + def get_cert(self, context: context.Context) -> Tuple[certs.Cert, SSL.PKey, str]: + return self.certstore.get_cert( + context.client.sni, [context.client.sni] + ) + + def tls_start(self, tls_start: tls.TlsStart): + if tls_start.conn == tls_start.context.client: + self.create_client_proxy_ssl_conn(tls_start) + else: + self.create_proxy_server_ssl_conn(tls_start) + + def create_client_proxy_ssl_conn(self, tls_start: tls.TlsStart) -> None: + tls_method, tls_options = net_tls.VERSION_CHOICES[ctx.options.ssl_version_client] + cert, key, chain_file = self.get_cert(tls_start.context) + if ctx.options.add_upstream_certs_to_client_chain: + raise NotImplementedError() + else: + extra_chain_certs = None + ssl_ctx = net_tls.create_server_context( + cert=cert, + key=key, + method=tls_method, + options=tls_options, + cipher_list=ctx.options.ciphers_client or DEFAULT_CLIENT_CIPHERS, + dhparams=self.certstore.dhparams, + chain_file=chain_file, + alpn_select_callback=alpn_select_callback, + extra_chain_certs=extra_chain_certs, + ) + tls_start.ssl_conn = SSL.Connection(ssl_ctx) + tls_start.ssl_conn.set_app_data({ + "server_alpn": tls_start.context.server.alpn + }) + + def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStart) -> None: + client = tls_start.context.client + server: context.Server = tls_start.conn + + if server.sni is True: + server.sni = client.sni or server.address[0].encode() + + if not server.alpn_offers: + if client.alpn: + server.alpn_offers = [client.alpn] + elif client.alpn_offers: + server.alpn_offers = client.alpn_offers + + # We pass through the list of ciphers send by the client, because some HTTP/2 servers + # will select a non-HTTP/2 compatible cipher from our default list and then hang up + # because it's incompatible with h2. + if not server.cipher_list: + if ctx.options.ciphers_server: + server.cipher_list = ctx.options.ciphers_server.split(":") + elif client.cipher_list: + server.cipher_list = [ + x for x in client.cipher_list + if x in CIPHER_ID_NAME_MAP + ] + + args = net_tls.client_arguments_from_options(ctx.options) + + client_certs = args.pop("client_certs") + client_cert: Optional[str] = None + if client_certs: + client_certs = os.path.expanduser(client_certs) + if os.path.isfile(client_certs): + client_cert = client_certs + else: + server_name: str = (server.sni or server.address[0].encode("idna")).decode() + path = os.path.join(client_certs, f"{server_name}.pem") + if os.path.exists(path): + client_cert = path + + args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None + ssl_ctx = net_tls.create_client_context( + cert=client_cert, + sni=server.sni.decode("idna"), # FIXME: Should pass-through here. + alpn_protos=server.alpn_offers, + **args + ) + tls_start.ssl_conn = SSL.Connection(ssl_ctx) + + def configure(self, updated): + if not any(x in updated for x in ["confdir", "certs"]): + return + + certstore_path = os.path.expanduser(ctx.options.confdir) + if not os.path.exists(os.path.dirname(certstore_path)): + raise exceptions.OptionsError( + f"Certificate Authority parent directory does not exist: {os.path.dirname(certstore_path)}") + self.certstore = certs.CertStore.from_store( + path=certstore_path, + basename=CONF_BASENAME, + key_size=ctx.options.key_size + ) + for certspec in ctx.options.certs: + parts = certspec.split("=", 1) + if len(parts) == 1: + parts = ["*", parts[0]] + + cert = os.path.expanduser(parts[1]) + if not os.path.exists(cert): + raise exceptions.OptionsError(f"Certificate file does not exist: {cert}") + try: + self.certstore.add_cert_file(parts[0], cert) + except crypto.Error as e: + raise exceptions.OptionsError(f"Invalid certificate format: {cert}") from e diff --git a/mitmproxy/proxy2/context.py b/mitmproxy/proxy2/context.py index 1574cc5b8..176d32c02 100644 --- a/mitmproxy/proxy2/context.py +++ b/mitmproxy/proxy2/context.py @@ -21,8 +21,12 @@ class Connection: tls_established: bool = False alpn: Optional[bytes] = None alpn_offers: Sequence[bytes] = () + cipher_list: Sequence[bytes] = () + tls_version: Optional[str] = None sni: Union[bytes, bool, None] + timestamp_tls_setup: Optional[float] = None + @property def connected(self): return self.state is ConnectionState.OPEN diff --git a/mitmproxy/proxy2/layer.py b/mitmproxy/proxy2/layer.py index 974968350..0e4da3c88 100644 --- a/mitmproxy/proxy2/layer.py +++ b/mitmproxy/proxy2/layer.py @@ -67,7 +67,7 @@ class Layer: @abstractmethod def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: """Handle a proxy server event""" - yield from () + yield from () # pragma: no cover def handle_event(self, event: events.Event) -> commands.TCommandGenerator: if self._paused: @@ -95,10 +95,7 @@ class Layer: processing any other commands. """ try: - if isinstance(send, Exception): - command = command_generator.throw(type(send), send) - else: - command = command_generator.send(send) + command = command_generator.send(send) except StopIteration: return diff --git a/mitmproxy/proxy2/layers/http/http.py b/mitmproxy/proxy2/layers/http/http.py index b0eadcdd7..9bb5e79c0 100644 --- a/mitmproxy/proxy2/layers/http/http.py +++ b/mitmproxy/proxy2/layers/http/http.py @@ -13,7 +13,7 @@ from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy2 import commands, events from mitmproxy.proxy2.context import Client, Connection, Context, Server from mitmproxy.proxy2.layer import Layer, NextLayer -from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply +from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS from mitmproxy.proxy2.utils import expect from mitmproxy.utils import human @@ -676,6 +676,9 @@ class HTTPLayer(Layer): def make_http_connection(self, connection: Server) -> None: if connection.tls and not connection.tls_established: + connection.alpn_offers = list(HTTP_ALPNS) + if not self.context.options.http2: + connection.alpn_offers.remove(b"h2") new_command = EstablishServerTLS(connection) new_command.blocking = object() yield new_command diff --git a/mitmproxy/proxy2/layers/tls.py b/mitmproxy/proxy2/layers/tls.py index 132fd6e00..1675ad379 100644 --- a/mitmproxy/proxy2/layers/tls.py +++ b/mitmproxy/proxy2/layers/tls.py @@ -1,12 +1,10 @@ -import os import struct +import time from typing import Any, Dict, Generator, Iterator, Optional, Tuple from OpenSSL import SSL -from mitmproxy.certs import CertStore -from mitmproxy.net.tls import ClientHello -from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS +from mitmproxy.net import tls as net_tls from mitmproxy.proxy2 import commands, events, layer from mitmproxy.proxy2 import context from mitmproxy.proxy2.utils import expect @@ -69,7 +67,7 @@ def get_client_hello(data: bytes) -> Optional[bytes]: return None -def parse_client_hello(data: bytes) -> Optional[ClientHello]: +def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]: """ Check if the supplied bytes contain a full ClientHello message, and if so, parse it. @@ -84,10 +82,13 @@ def parse_client_hello(data: bytes) -> Optional[ClientHello]: # Check if ClientHello is complete client_hello = get_client_hello(data) if client_hello: - return ClientHello(client_hello[4:]) + return net_tls.ClientHello(client_hello[4:]) return None +HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9") + + class EstablishServerTLS(commands.ConnectionCommand): connection: context.Server blocking = True @@ -99,9 +100,17 @@ class EstablishServerTLSReply(events.CommandReply): """error message""" +class TlsStart: + def __init__(self, conn: context.Connection, context: context.Context) -> None: + self.conn = conn + self.context = context + self.ssl_conn = None + + class _TLSLayer(layer.Layer): tls: Dict[context.Connection, SSL.Connection] child_layer: layer.Layer + ssl_context: Optional[SSL.Context] = None def __init__(self, context: context.Context): super().__init__(context) @@ -140,15 +149,18 @@ class _TLSLayer(layer.Layer): except SSL.WantReadError: yield from self.tls_interact(conn) return False, None - except SSL.ZeroReturnError as e: + except SSL.Error as e: return False, repr(e) else: conn.tls_established = True + conn.sni = self.tls[conn].get_servername() conn.alpn = self.tls[conn].get_alpn_proto_negotiated() + conn.cipher_list = self.tls[conn].get_cipher_list() + conn.tls_version = self.tls[conn].get_protocol_version_name() + conn.timestamp_tls_setup = time.time() yield commands.Log(f"TLS established: {conn}") yield from self.receive(conn, b"") # TODO: Set all other connection attributes here - # there might already be data in the OpenSSL BIO, so we need to trigger its processing. return True, None def receive(self, conn: context.Connection, data: bytes): @@ -213,8 +225,8 @@ class ServerTLSLayer(_TLSLayer): self.command_to_reply_to = {} self.child_layer = layer.NextLayer(self.context) - def negotiate(self, conn: context.Connection, data: bytes) -> Generator[ - commands.Command, Any, Tuple[bool, Optional[str]]]: + def negotiate(self, conn: context.Connection, data: bytes) \ + -> Generator[commands.Command, Any, Tuple[bool, Optional[str]]]: done, err = yield from super().negotiate(conn, data) if done or err: cmd = self.command_to_reply_to.pop(conn) @@ -232,19 +244,11 @@ class ServerTLSLayer(_TLSLayer): def start_server_tls(self, conn: context.Server): assert conn not in self.tls assert conn.connected + conn.tls = True - ssl_context = SSL.Context(SSL.SSLv23_METHOD) - if conn.alpn_offers: - ssl_context.set_alpn_protos(conn.alpn_offers) - self.tls[conn] = SSL.Connection(ssl_context) - - if conn.sni: - if conn.sni is True: - if self.context.client.sni: - conn.sni = self.context.client.sni - else: - conn.sni = conn.address[0].encode() - self.tls[conn].set_tlsext_host_name(conn.sni) + tls_start = TlsStart(conn, self.context) + yield commands.Hook("tls_start", tls_start) + self.tls[conn] = tls_start.ssl_conn self.tls[conn].set_connect_state() yield from self.negotiate(conn, b"") @@ -274,6 +278,7 @@ class ClientTLSLayer(_TLSLayer): super().__init__(context) self.recv_buffer = bytearray() self.child_layer = layer.NextLayer(self.context) + self._handle_event = self.state_start @expect(events.Start) def state_start(self, _) -> commands.TCommandGenerator: @@ -281,9 +286,6 @@ class ClientTLSLayer(_TLSLayer): self._handle_event = self.state_wait_for_clienthello yield from () - _handle_event = state_start - - @expect(events.DataReceived, events.ConnectionClosed) def state_wait_for_clienthello(self, event: events.Event): client = self.context.client if isinstance(event, events.DataReceived) and event.connection == client: @@ -296,8 +298,7 @@ class ClientTLSLayer(_TLSLayer): if client_hello: yield commands.Log(f"Client Hello: {client_hello}") - # TODO: Don't do double conversion - client.sni = client_hello.sni.encode("idna") + client.sni = client_hello.sni client.alpn_offers = client_hello.alpn_protocols client_tls_requires_server_connection = ( @@ -322,8 +323,10 @@ class ClientTLSLayer(_TLSLayer): # In any case, we now have enough information to start server TLS if needed. yield from self.event_to_child(events.Start()) + elif isinstance(event, events.ConnectionClosed) and event.connection == client: + self.recv_buffer.clear() else: - raise NotImplementedError(event) # TODO + yield from self.event_to_child(event) def start_server_tls(self): """ @@ -339,11 +342,6 @@ class ClientTLSLayer(_TLSLayer): ) return err - server.alpn_offers = [ - x for x in self.context.client.alpn_offers - if not (x.startswith(b"h2-") or x.startswith(b"spdy")) - ] - err = yield EstablishServerTLS(server) if err: yield commands.Log( @@ -352,36 +350,10 @@ class ClientTLSLayer(_TLSLayer): return err def start_client_tls(self) -> commands.TCommandGenerator: - # FIXME: Do this properly. Also adjust error message in negotiate() client = self.context.client - server = self.context.server - context = SSL.Context(SSL.SSLv23_METHOD) - cert, privkey, cert_chain = CertStore.from_store( - os.path.expanduser("~/.mitmproxy"), "mitmproxy", - self.context.options.key_size - ).get_cert(client.sni, (client.sni,)) - context.use_privatekey(privkey) - context.use_certificate(cert.x509) - context.set_cipher_list(DEFAULT_CLIENT_CIPHERS) - - def alpn_select_callback(conn_, options): - if server.alpn in options: - return server.alpn - elif b"h2" in options: - return b"h2" - elif b"http/1.1" in options: - return b"http/1.1" - elif b"http/1.0" in options: - return b"http/1.0" - elif b"http/0.9" in options: - return b"http/0.9" - else: - # FIXME: We MUST return something here. At this point we are at loss. - return options[0] - - context.set_alpn_select_callback(alpn_select_callback) - - self.tls[client] = SSL.Connection(context) + tls_start = TlsStart(client, self.context) + yield commands.Hook("tls_start", tls_start) + self.tls[client] = tls_start.ssl_conn self.tls[client].set_accept_state() yield from self.negotiate(client, bytes(self.recv_buffer)) @@ -390,11 +362,16 @@ class ClientTLSLayer(_TLSLayer): def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]: done, err = yield from super().negotiate(conn, data) if err: + if self.context.client.sni: + # TODO: Also use other sources than SNI + dest = " for " + self.context.client.sni.decode("idna") + else: + dest = "" yield commands.Log( f"Client TLS Handshake failed. " - f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).", + f"The client may not trust the proxy's certificate{dest} ({err}).", level="warn" - # TODO: Also use other sources than SNI + ) yield commands.CloseConnection(self.context.client) return done diff --git a/test/mitmproxy/proxy2/layers/test_tcp.py b/test/mitmproxy/proxy2/layers/test_tcp.py index e6fe97a23..5217ab547 100644 --- a/test/mitmproxy/proxy2/layers/test_tcp.py +++ b/test/mitmproxy/proxy2/layers/test_tcp.py @@ -24,7 +24,7 @@ def test_open_connection(tctx): def test_open_connection_err(tctx): f = Placeholder() assert ( - playbook(TCPLayer(tctx), hooks=True) + playbook(TCPLayer(tctx)) << Hook("tcp_start", f) >> reply() << OpenConnection(tctx.server) @@ -40,7 +40,7 @@ def test_simple(tctx): f = Placeholder() assert ( - playbook(TCPLayer(tctx), hooks=True) + playbook(TCPLayer(tctx)) << Hook("tcp_start", f) >> reply() << OpenConnection(tctx.server) @@ -71,7 +71,7 @@ def test_receive_data_before_server_connected(tctx): will still be forwarded. """ assert ( - playbook(TCPLayer(tctx)) + playbook(TCPLayer(tctx), hooks=False) << OpenConnection(tctx.server) >> DataReceived(tctx.client, b"hello!") >> reply(None, to=-2) diff --git a/test/mitmproxy/proxy2/layers/test_tls.py b/test/mitmproxy/proxy2/layers/test_tls.py index efe7836dd..8f0da853f 100644 --- a/test/mitmproxy/proxy2/layers/test_tls.py +++ b/test/mitmproxy/proxy2/layers/test_tls.py @@ -3,11 +3,15 @@ import ssl import typing import pytest +from OpenSSL import SSL -from mitmproxy.proxy2 import context, events, commands +from mitmproxy.proxy2 import commands, context, events from mitmproxy.proxy2.layers import tls +from mitmproxy.utils import data from test.mitmproxy.proxy2 import tutils +tlsdata = data.Data(__name__) + def test_is_tls_handshake_record(): assert tls.is_tls_handshake_record(bytes.fromhex("160300")) @@ -33,11 +37,11 @@ def test_record_contents(): def test_record_contents_err(): - with pytest.raises(ValueError, msg="Expected TLS record"): + with pytest.raises(ValueError, match="Expected TLS record"): next(tls.handshake_record_contents(b"GET /error")) empty_record = bytes.fromhex("1603010000") - with pytest.raises(ValueError, msg="Record must not be empty"): + with pytest.raises(ValueError, match="Record must not be empty"): next(tls.handshake_record_contents(empty_record)) @@ -53,8 +57,8 @@ def test_get_client_hello(): assert tls.get_client_hello(single_record) == client_hello_no_extensions split_over_two_records = ( - bytes.fromhex("1603010020") + client_hello_no_extensions[:32] + - bytes.fromhex("1603010045") + client_hello_no_extensions[32:] + bytes.fromhex("1603010020") + client_hello_no_extensions[:32] + + bytes.fromhex("1603010045") + client_hello_no_extensions[32:] ) assert tls.get_client_hello(split_over_two_records) == client_hello_no_extensions @@ -65,7 +69,8 @@ def test_get_client_hello(): class SSLTest: """Helper container for Python's builtin SSL object.""" - def __init__(self, server_side=False, alpn=None): + def __init__(self, server_side: bool = False, alpn: typing.List[bytes] = None, + sni: typing.Optional[bytes] = b"example.com"): self.inc = ssl.MemoryBIO() self.out = ssl.MemoryBIO() self.ctx = ssl.SSLContext() @@ -77,83 +82,78 @@ class SSLTest: self.obj = self.ctx.wrap_bio( self.inc, self.out, - server_hostname=None if server_side else "example.com", + server_hostname=None if server_side else sni, server_side=server_side, ) -def _test_tls_client_server( - tctx: context.Context, - alpn: typing.Optional[str] -) -> typing.Tuple[tutils.playbook[tls.ClientTLSLayer], SSLTest]: - layer = tls.ClientTLSLayer(tctx) - playbook = tutils.playbook(layer) - tctx.server.tls = True - tctx.server.address = ("example.com", 443) - tssl_client = SSLTest(alpn=alpn) - - # Handshake - assert ( - playbook - << None - ) - - with pytest.raises(ssl.SSLWantReadError): - tssl_client.obj.do_handshake() - client_hello = tssl_client.out.read() - assert ( - playbook - >> events.DataReceived(tctx.client, client_hello[:42]) - << None - ) - # Still waiting... - # Finish sending ClientHello - playbook >> events.DataReceived(tctx.client, client_hello[42:]) - return playbook, tssl_client - - -def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None: +def _test_echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None: tssl.obj.write(b"Hello World") data = tutils.Placeholder() assert ( - playbook - >> events.DataReceived(conn, tssl.out.read()) - << commands.Hook("next_layer", tutils.Placeholder()) - >> tutils.next_layer(tutils.EchoLayer) - << commands.SendData(conn, data) + playbook + >> events.DataReceived(conn, tssl.out.read()) + << commands.SendData(conn, data) ) tssl.inc.write(data()) assert tssl.obj.read() == b"hello world" +class TlsEchoLayer(tutils.EchoLayer): + err: typing.Optional[str] = None + + def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: + if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls": + # noinspection PyTypeChecker + self.err = yield tls.EstablishServerTLS(self.context.server) + else: + yield from super()._handle_event(event) + + +def interact(playbook: tutils.playbook, conn: context.Connection, tssl: SSLTest): + data = tutils.Placeholder() + assert ( + playbook + >> events.DataReceived(conn, tssl.out.read()) + << commands.SendData(conn, data) + ) + tssl.inc.write(data()) + + +def reply_tls_start(*args, **kwargs) -> tutils.reply: + """ + Helper function to simplify the syntax for tls_start hooks. + """ + + def make_conn(hook: commands.Hook) -> None: + tls_start = hook.data + assert isinstance(tls_start, tls.TlsStart) + ssl_context = SSL.Context(SSL.SSLv23_METHOD) + if tls_start.conn == tls_start.context.client: + ssl_context.use_privatekey_file( + tlsdata.path("../../net/data/verificationcerts/trusted-leaf.key") + ) + ssl_context.use_certificate_chain_file( + tlsdata.path("../../net/data/verificationcerts/trusted-leaf.crt") + ) + tls_start.ssl_conn = SSL.Connection(ssl_context) + + return tutils.reply(*args, side_effect=make_conn, **kwargs) + + class TestServerTLS: def test_no_tls(self, tctx: context.Context): """Test TLS layer without TLS""" layer = tls.ServerTLSLayer(tctx) - playbook = tutils.playbook(layer) + layer.child_layer = TlsEchoLayer(tctx) # Handshake assert ( - playbook - >> events.DataReceived(tctx.client, b"Hello World") - << commands.Hook("next_layer", tutils.Placeholder()) - >> tutils.next_layer(tutils.EchoLayer) - << commands.SendData(tctx.client, b"hello world") - ) - - def test_no_connection(self, tctx): - """ - The server TLS layer is initiated, but there is no active connection yet, so nothing - should be done. - """ - layer = tls.ServerTLSLayer(tctx) - playbook = tutils.playbook(layer) - tctx.server.tls = True - - # We did not have a server connection before, so let's do nothing. - assert ( - playbook - << None + tutils.playbook(layer) + >> events.DataReceived(tctx.client, b"Hello World") + << commands.SendData(tctx.client, b"hello world") + >> events.DataReceived(tctx.server, b"Foo") + << commands.SendData(tctx.server, b"foo") ) def test_simple(self, tctx): @@ -161,166 +161,171 @@ class TestServerTLS: playbook = tutils.playbook(layer) tctx.server.connected = True tctx.server.address = ("example.com", 443) - tctx.server.tls = True tssl = SSLTest(server_side=True) # send ClientHello data = tutils.Placeholder() assert ( - playbook - << commands.SendData(tctx.server, data) + playbook + >> events.DataReceived(tctx.client, b"establish-server-tls") + << commands.Hook("next_layer", tutils.Placeholder()) + >> tutils.next_layer(TlsEchoLayer) + << commands.Hook("tls_start", tutils.Placeholder()) + >> reply_tls_start() + << commands.SendData(tctx.server, data) ) # receive ServerHello, finish client handshake tssl.inc.write(data()) with pytest.raises(ssl.SSLWantReadError): tssl.obj.do_handshake() - data = tutils.Placeholder() - assert ( - playbook - >> events.DataReceived(tctx.server, tssl.out.read()) - << commands.SendData(tctx.server, data) - ) - tssl.inc.write(data()) + interact(playbook, tctx.server, tssl) # finish server handshake tssl.obj.do_handshake() assert ( - playbook - >> events.DataReceived(tctx.server, tssl.out.read()) - << None + playbook + >> events.DataReceived(tctx.server, tssl.out.read()) + << None ) assert tctx.server.tls_established - assert tctx.server.sni == b"example.com" # Echo - echo(playbook, tssl, tctx.server) + _test_echo(playbook, tssl, tctx.server) + + +def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer]: + # This is a bit contrived as the client layer expects a server layer as parent. + # We also set child layers manually to avoid NextLayer noise. + server_layer = tls.ServerTLSLayer(tctx) + client_layer = tls.ClientTLSLayer(tctx) + server_layer.child_layer = client_layer + client_layer.child_layer = TlsEchoLayer(tctx) + playbook = tutils.playbook(server_layer) + return playbook, client_layer + + +def _test_tls_client_server( + tctx: context.Context, + sni: typing.Optional[bytes] +) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer, SSLTest]: + playbook, client_layer = _make_client_tls_layer(tctx) + tctx.server.tls = True + tctx.server.address = ("example.com", 443) + tssl_client = SSLTest(sni=sni) + + # Send ClientHello + with pytest.raises(ssl.SSLWantReadError): + tssl_client.obj.do_handshake() + + return playbook, client_layer, tssl_client class TestClientTLS: - def test_simple(self, tctx: context.Context): + def test_client_only(self, tctx: context.Context): """Test TLS with client only""" - layer = tls.ClientTLSLayer(tctx) - playbook = tutils.playbook(layer) + playbook, client_layer = _make_client_tls_layer(tctx) tssl = SSLTest() + assert not tctx.client.tls_established - # Handshake - assert playbook - assert layer._handle_event == layer.state_wait_for_clienthello - - def interact(): - data = tutils.Placeholder() - assert ( - playbook - >> events.DataReceived(tctx.client, tssl.out.read()) - << commands.SendData(tctx.client, data) - ) - tssl.inc.write(data()) - try: - tssl.obj.do_handshake() - except ssl.SSLWantReadError: - return False - else: - return True - - # receive ClientHello, send ServerHello + # Start Handshake, send ClientHello and ServerHello with pytest.raises(ssl.SSLWantReadError): tssl.obj.do_handshake() - assert not interact() - # Finish Handshake - assert interact() + data = tutils.Placeholder() + assert ( + playbook + >> events.DataReceived(tctx.client, tssl.out.read()) + << commands.Hook("tls_start", tutils.Placeholder()) + >> reply_tls_start() + << commands.SendData(tctx.client, data) + ) + tssl.inc.write(data()) tssl.obj.do_handshake() + # Finish Handshake + interact(playbook, tctx.client, tssl) - assert layer._handle_event == layer.state_process + assert tssl.obj.getpeercert(True) + assert tctx.client.tls_established # Echo - echo(playbook, tssl, tctx.client) + _test_echo(playbook, tssl, tctx.client) assert ( - playbook - >> events.DataReceived(tctx.server, b"Hello") - << commands.SendData(tctx.server, b"hello") + playbook + >> events.DataReceived(tctx.server, b"Plaintext") + << commands.SendData(tctx.server, b"plaintext") ) - def test_no_server_conn_required(self, tctx): + def test_server_not_required(self, tctx): """ Here we test the scenario where a server connection is _not_ required to establish TLS with the client. After determining this when parsing the ClientHello, we only establish a connection with the client. The server connection may ultimately be established when OpenConnection is called. """ - playbook, _ = _test_tls_client_server(tctx, None) + playbook, client_layer, tssl = _test_tls_client_server(tctx, sni=b"example.com") data = tutils.Placeholder() assert ( - playbook - << commands.SendData(tctx.client, data) + playbook + >> events.DataReceived(tctx.client, tssl.out.read()) + << commands.Hook("tls_start", tutils.Placeholder()) + >> reply_tls_start() + << commands.SendData(tctx.client, data) ) - assert data() - assert playbook.layer._handle_event == playbook.layer.state_process + tssl.inc.write(data()) + tssl.obj.do_handshake() + interact(playbook, tctx.client, tssl) + assert tctx.client.tls_established - def test_alpn(self, tctx): + def test_server_required(self, tctx): """ - Here we test the scenario where a server connection is required (e.g. because of ALPN negotation) + Here we test the scenario where a server connection is required (because SNI is missing) to establish TLS with the client. """ - tssl_server = SSLTest(server_side=True, alpn=["foo", "bar"]) - - playbook, tssl_client = _test_tls_client_server(tctx, ["qux", "foo"]) + tssl_server = SSLTest(server_side=True) + playbook, client_layer, tssl_client = _test_tls_client_server(tctx, sni=None) # We should now get instructed to open a server connection. - assert ( - playbook - << commands.OpenConnection(tctx.server) - ) - tctx.server.connected = True data = tutils.Placeholder() assert ( - playbook - >> events.OpenConnectionReply(-1, None) - << commands.SendData(tctx.server, data) + playbook + >> events.DataReceived(tctx.client, tssl_client.out.read()) + << commands.OpenConnection(tctx.server) + >> tutils.reply(None) + << commands.Hook("tls_start", tutils.Placeholder()) + >> reply_tls_start() + << commands.SendData(tctx.server, data) ) - assert playbook.layer._handle_event == playbook.layer.state_wait_for_server_tls - assert playbook.layer.child_layer.tls[tctx.server] # Establish TLS with the server... tssl_server.inc.write(data()) with pytest.raises(ssl.SSLWantReadError): tssl_server.obj.do_handshake() + data = tutils.Placeholder() assert ( - playbook - >> events.DataReceived(tctx.server, tssl_server.out.read()) - << commands.SendData(tctx.server, data) + playbook + >> events.DataReceived(tctx.server, tssl_server.out.read()) + << commands.SendData(tctx.server, data) + << commands.Hook("tls_start", tutils.Placeholder()) ) tssl_server.inc.write(data()) - tssl_server.obj.do_handshake() - data = tutils.Placeholder() - assert ( - playbook - >> events.DataReceived(tctx.server, tssl_server.out.read()) - << commands.SendData(tctx.client, data) - ) - - assert playbook.layer._handle_event == playbook.layer.state_process assert tctx.server.tls_established - # Server TLS is established, we can now reply to the client handshake... - tssl_client.inc.write(data()) - with pytest.raises(ssl.SSLWantReadError): - tssl_client.obj.do_handshake() + data = tutils.Placeholder() assert ( - playbook - >> events.DataReceived(tctx.client, tssl_client.out.read()) - << commands.SendData(tctx.client, data) + playbook + >> reply_tls_start() + << commands.SendData(tctx.client, data) ) tssl_client.inc.write(data()) tssl_client.obj.do_handshake() + interact(playbook, tctx.client, tssl_client) # Both handshakes completed! assert tctx.client.tls_established assert tctx.server.tls_established - - assert tssl_client.obj.selected_alpn_protocol() == "foo" - assert tssl_server.obj.selected_alpn_protocol() == "foo" + _test_echo(playbook, tssl_server, tctx.server) + _test_echo(playbook, tssl_client, tctx.client) diff --git a/test/mitmproxy/proxy2/test_layer.py b/test/mitmproxy/proxy2/test_layer.py index c140c5a61..3a8efc95f 100644 --- a/test/mitmproxy/proxy2/test_layer.py +++ b/test/mitmproxy/proxy2/test_layer.py @@ -5,13 +5,13 @@ from test.mitmproxy.proxy2 import tutils class TestNextLayer: def test_simple(self, tctx): nl = layer.NextLayer(tctx) - playbook = tutils.playbook(nl) + playbook = tutils.playbook(nl, hooks=True) assert ( playbook >> events.DataReceived(tctx.client, b"foo") << commands.Hook("next_layer", nl) - >> events.HookReply(-1) + >> tutils.reply() >> events.DataReceived(tctx.client, b"bar") << commands.Hook("next_layer", nl) ) @@ -21,7 +21,7 @@ class TestNextLayer: nl.layer = tutils.EchoLayer(tctx) assert ( playbook - >> events.HookReply(-1) + >> tutils.reply() << commands.SendData(tctx.client, b"foo") << commands.SendData(tctx.client, b"bar") ) @@ -45,7 +45,7 @@ class TestNextLayer: assert ( playbook - >> events.HookReply(-2) + >> tutils.reply(to=-2) << commands.SendData(tctx.client, b"foo") << commands.SendData(tctx.client, b"bar") ) @@ -63,7 +63,7 @@ class TestNextLayer: handle = nl.handle_event assert ( playbook - >> events.HookReply(-1) + >> tutils.reply() << commands.SendData(tctx.client, b"foo") ) sd, = handle(events.DataReceived(tctx.client, b"bar")) diff --git a/test/mitmproxy/proxy2/test_tutils.py b/test/mitmproxy/proxy2/test_tutils.py index 995976449..29b65171d 100644 --- a/test/mitmproxy/proxy2/test_tutils.py +++ b/test/mitmproxy/proxy2/test_tutils.py @@ -38,7 +38,7 @@ class TLayer(Layer): @pytest.fixture def tplaybook(tctx): - return tutils.playbook(TLayer(tctx), []) + return tutils.playbook(TLayer(tctx), expected=[]) def test_simple(tplaybook): @@ -158,7 +158,7 @@ def test_command_reply(tplaybook): tplaybook >> TEvent() << TCommand() - >> TCommandReply(-1, 42) + >> tutils.reply(42) ) assert tplaybook.actual[1] == tplaybook.actual[2].command diff --git a/test/mitmproxy/proxy2/tutils.py b/test/mitmproxy/proxy2/tutils.py index eab7d35fe..c45440033 100644 --- a/test/mitmproxy/proxy2/tutils.py +++ b/test/mitmproxy/proxy2/tutils.py @@ -2,6 +2,7 @@ import collections.abc import copy import difflib import itertools +import sys import typing from mitmproxy.proxy2 import commands, context @@ -101,7 +102,7 @@ class playbook: def __init__( self, layer: Layer, - hooks: bool = False, + hooks: bool = True, logs: bool = False, expected: typing.Optional[TPlaybook] = None, ): @@ -196,13 +197,13 @@ class playbook: class reply(events.Event): args: typing.Tuple[typing.Any, ...] to: typing.Union[commands.Command, int] - side_effect: typing.Callable[[commands.Command], typing.Any] + side_effect: typing.Callable[[typing.Any], typing.Any] def __init__( self, *args, to: typing.Union[commands.Command, int] = -1, - side_effect: typing.Callable[[commands.Command], typing.Any] = lambda cmd: None + side_effect: typing.Callable[[typing.Any], None] = lambda x: None ): """Utility method to reply to the latest hook in playbooks.""" self.args = args @@ -226,6 +227,7 @@ class reply(events.Event): actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual) raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}") + assert isinstance(self.to, commands.Command) self.side_effect(self.to) reply_cls = command_reply_subclasses[type(self.to)] try: @@ -272,14 +274,16 @@ def Placeholder() -> typing.Any: class EchoLayer(Layer): """Echo layer that sends all data back to the client in lowercase.""" - def _handle_event(self, event: events.Event): + def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: if isinstance(event, events.DataReceived): yield commands.SendData(event.connection, event.data.lower()) def next_layer( - layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]] -) -> events.HookReply: + layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]], + *args, + **kwargs +) -> reply: """ Helper function to simplify the syntax for next_layer events from this: @@ -294,21 +298,10 @@ def next_layer( << commands.Hook("next_layer", next_layer) >> tutils.next_layer(next_layer, tutils.EchoLayer) - >> tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context) """ - raise RuntimeError("Does tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context) work?") - if isinstance(layer, type): - def make_layer(ctx: context.Context) -> Layer: - return layer(ctx) - else: - make_layer = layer - def set_layer(playbook: playbook) -> None: - last_command = playbook.actual[-1] - assert isinstance(last_command, commands.Hook) - assert isinstance(last_command.data, NextLayer) - last_command.data.layer = make_layer(last_command.data.context) + def set_layer(hook: commands.Hook) -> None: + assert isinstance(hook.data, NextLayer) + hook.data.layer = layer(hook.data.context) - reply = events.HookReply(-1) - reply._playbook_eval = set_layer - return reply + return reply(*args, side_effect=set_layer, **kwargs)