diff --git a/test/mitmproxy/proxy2/layers/test_tls.py b/test/mitmproxy/proxy2/layers/test_tls.py index 93303b055..e2018c6d2 100644 --- a/test/mitmproxy/proxy2/layers/test_tls.py +++ b/test/mitmproxy/proxy2/layers/test_tls.py @@ -1,3 +1,4 @@ +import os import ssl import pytest @@ -62,15 +63,35 @@ def test_get_client_hello(): class SSLTest: """Helper container for Python's builtin SSL object.""" - def __init__(self): + + def __init__(self, server_side=False): self.inc = ssl.MemoryBIO() self.out = ssl.MemoryBIO() self.ctx = ssl.SSLContext() - self.obj = self.ctx.wrap_bio(self.inc, self.out, server_side=False) - try: - self.obj.do_handshake() - except ssl.SSLWantReadError: - pass + if server_side: + # FIXME: Replace hardcoded location + self.ctx.load_cert_chain(os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem")) + self.obj = self.ctx.wrap_bio(self.inc, self.out, server_side=server_side) + + +def test_no_tls(tctx: context.Context): + """Test TLS layer without TLS""" + layer = tls.TLSLayer(tctx) + playbook = tutils.playbook(layer) + next_layer = tutils.Placeholder() + + # Handshake + assert ( + playbook + >> events.DataReceived(tctx.client, b"Hello World") + << commands.Hook("next_layer", next_layer) + ) + next_layer().layer = tutils.EchoLayer(next_layer().context) + assert ( + playbook + >> events.HookReply(-1) + << commands.SendData(tctx.client, b"hello world") + ) def test_client_tls(tctx: context.Context): @@ -93,35 +114,100 @@ def test_client_tls(tctx: context.Context): << commands.SendData(tctx.client, data) ) tssl.inc.write(data()) + try: + tssl.obj.do_handshake() + except ssl.SSLWantReadError: + return False + else: + return True - # Send ClientHello, receive ServerHello - interact() + # receive ClientHello, send ServerHello with pytest.raises(ssl.SSLWantReadError): tssl.obj.do_handshake() + assert not interact() # Finish Handshake - interact() + assert interact() tssl.obj.do_handshake() assert layer.state[tctx.client] == tls.ConnectionState.ESTABLISHED assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS # Echo + echo(playbook, tssl, tctx.client) + + +def echo(playbook, tssl, conn): + tconn = type(conn).__name__.lower() tssl.obj.write(b"Hello World") - nextl = tutils.Placeholder() + next_layer = tutils.Placeholder() assert ( playbook - >> events.DataReceived(tctx.client, tssl.out.read()) - << commands.Log( - "PlainDataReceived(client, b'Hello World')") - << commands.Hook("next_layer", nextl) + >> events.DataReceived(conn, tssl.out.read()) + << commands.Log(f"PlainDataReceived({tconn}, b'Hello World')") + << commands.Hook("next_layer", next_layer) ) - nextl().layer = tutils.EchoLayer(nextl().context) + next_layer().layer = tutils.EchoLayer(next_layer().context) data = tutils.Placeholder() assert ( playbook >> events.HookReply(-1) - << commands.Log("PlainSendData(client, b'hello world')") - << commands.SendData(tctx.client, data) + << commands.Log(f"PlainSendData({tconn}, b'hello world')") + << commands.SendData(conn, data) ) tssl.inc.write(data()) assert tssl.obj.read() == b"hello world" + + +def test_server_tls_no_conn(tctx): + layer = tls.TLSLayer(tctx) + playbook = tutils.playbook(layer) + tctx.server.tls = True + + # We did not have a server connection before, so let's do nothing. + assert playbook + assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS + assert layer.state[tctx.server] == tls.ConnectionState.NO_TLS + + +def test_server_tls(tctx): + layer = tls.TLSLayer(tctx) + playbook = tutils.playbook(layer) + tctx.server.connected = True + tctx.server.tls = True + + tssl = SSLTest(server_side=True) + + # send ClientHello + data = tutils.Placeholder() + assert ( + playbook + << commands.SendData(tctx.server, data) + ) + assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS + assert layer.state[tctx.server] == tls.ConnectionState.NEGOTIATING + + # receive ServerHello, finish client handshake + tssl.inc.write(data()) + with pytest.raises(ssl.SSLWantReadError): + tssl.obj.do_handshake() + data = tutils.Placeholder() + assert ( + playbook + >> events.DataReceived(tctx.server, tssl.out.read()) + << commands.SendData(tctx.server, data) + ) + tssl.inc.write(data()) + + # finish server handshake + tssl.obj.do_handshake() + assert ( + playbook + >> events.DataReceived(tctx.server, tssl.out.read()) + << None + ) + + assert layer.state[tctx.client] == tls.ConnectionState.NO_TLS + assert layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED + + # Echo + echo(playbook, tssl, tctx.server)