[sans-io] tls layer++

This commit is contained in:
Maximilian Hils 2019-11-11 18:32:01 +01:00
parent 0c04638d8d
commit 1c80dfe17f
10 changed files with 371 additions and 254 deletions

View File

@ -0,0 +1,138 @@
import os
from typing import Optional, Tuple
from OpenSSL import SSL, crypto
from mitmproxy import certs, ctx, exceptions
from mitmproxy.net import tls as net_tls
from mitmproxy.options import CONF_BASENAME
from mitmproxy.proxy.protocol.tls import CIPHER_ID_NAME_MAP, DEFAULT_CLIENT_CIPHERS
from mitmproxy.proxy2 import context
from mitmproxy.proxy2.layers import tls
def alpn_select_callback(conn: SSL.Connection, options):
server_alpn = conn.get_app_data()["server_alpn"]
if server_alpn and server_alpn in options:
return server_alpn
for alpn in tls.HTTP_ALPNS:
if alpn in options:
return alpn
else:
# FIXME: pyOpenSSL requires that an ALPN is negotiated, we can't return SSL_TLSEXT_ERR_NOACK.
return options[0]
class TlsConfig:
certstore: certs.CertStore
# TODO: We should re-use SSL.Context options here, if only for TLS session resumption.
# This may require patches to pyOpenSSL, as some functionality is only exposed on contexts.
def get_cert(self, context: context.Context) -> Tuple[certs.Cert, SSL.PKey, str]:
return self.certstore.get_cert(
context.client.sni, [context.client.sni]
)
def tls_start(self, tls_start: tls.TlsStart):
if tls_start.conn == tls_start.context.client:
self.create_client_proxy_ssl_conn(tls_start)
else:
self.create_proxy_server_ssl_conn(tls_start)
def create_client_proxy_ssl_conn(self, tls_start: tls.TlsStart) -> None:
tls_method, tls_options = net_tls.VERSION_CHOICES[ctx.options.ssl_version_client]
cert, key, chain_file = self.get_cert(tls_start.context)
if ctx.options.add_upstream_certs_to_client_chain:
raise NotImplementedError()
else:
extra_chain_certs = None
ssl_ctx = net_tls.create_server_context(
cert=cert,
key=key,
method=tls_method,
options=tls_options,
cipher_list=ctx.options.ciphers_client or DEFAULT_CLIENT_CIPHERS,
dhparams=self.certstore.dhparams,
chain_file=chain_file,
alpn_select_callback=alpn_select_callback,
extra_chain_certs=extra_chain_certs,
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
tls_start.ssl_conn.set_app_data({
"server_alpn": tls_start.context.server.alpn
})
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStart) -> None:
client = tls_start.context.client
server: context.Server = tls_start.conn
if server.sni is True:
server.sni = client.sni or server.address[0].encode()
if not server.alpn_offers:
if client.alpn:
server.alpn_offers = [client.alpn]
elif client.alpn_offers:
server.alpn_offers = client.alpn_offers
# We pass through the list of ciphers send by the client, because some HTTP/2 servers
# will select a non-HTTP/2 compatible cipher from our default list and then hang up
# because it's incompatible with h2.
if not server.cipher_list:
if ctx.options.ciphers_server:
server.cipher_list = ctx.options.ciphers_server.split(":")
elif client.cipher_list:
server.cipher_list = [
x for x in client.cipher_list
if x in CIPHER_ID_NAME_MAP
]
args = net_tls.client_arguments_from_options(ctx.options)
client_certs = args.pop("client_certs")
client_cert: Optional[str] = None
if client_certs:
client_certs = os.path.expanduser(client_certs)
if os.path.isfile(client_certs):
client_cert = client_certs
else:
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
path = os.path.join(client_certs, f"{server_name}.pem")
if os.path.exists(path):
client_cert = path
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None
ssl_ctx = net_tls.create_client_context(
cert=client_cert,
sni=server.sni.decode("idna"), # FIXME: Should pass-through here.
alpn_protos=server.alpn_offers,
**args
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
def configure(self, updated):
if not any(x in updated for x in ["confdir", "certs"]):
return
certstore_path = os.path.expanduser(ctx.options.confdir)
if not os.path.exists(os.path.dirname(certstore_path)):
raise exceptions.OptionsError(
f"Certificate Authority parent directory does not exist: {os.path.dirname(certstore_path)}")
self.certstore = certs.CertStore.from_store(
path=certstore_path,
basename=CONF_BASENAME,
key_size=ctx.options.key_size
)
for certspec in ctx.options.certs:
parts = certspec.split("=", 1)
if len(parts) == 1:
parts = ["*", parts[0]]
cert = os.path.expanduser(parts[1])
if not os.path.exists(cert):
raise exceptions.OptionsError(f"Certificate file does not exist: {cert}")
try:
self.certstore.add_cert_file(parts[0], cert)
except crypto.Error as e:
raise exceptions.OptionsError(f"Invalid certificate format: {cert}") from e

View File

@ -21,8 +21,12 @@ class Connection:
tls_established: bool = False tls_established: bool = False
alpn: Optional[bytes] = None alpn: Optional[bytes] = None
alpn_offers: Sequence[bytes] = () alpn_offers: Sequence[bytes] = ()
cipher_list: Sequence[bytes] = ()
tls_version: Optional[str] = None
sni: Union[bytes, bool, None] sni: Union[bytes, bool, None]
timestamp_tls_setup: Optional[float] = None
@property @property
def connected(self): def connected(self):
return self.state is ConnectionState.OPEN return self.state is ConnectionState.OPEN

View File

@ -67,7 +67,7 @@ class Layer:
@abstractmethod @abstractmethod
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
"""Handle a proxy server event""" """Handle a proxy server event"""
yield from () yield from () # pragma: no cover
def handle_event(self, event: events.Event) -> commands.TCommandGenerator: def handle_event(self, event: events.Event) -> commands.TCommandGenerator:
if self._paused: if self._paused:
@ -95,9 +95,6 @@ class Layer:
processing any other commands. processing any other commands.
""" """
try: try:
if isinstance(send, Exception):
command = command_generator.throw(type(send), send)
else:
command = command_generator.send(send) command = command_generator.send(send)
except StopIteration: except StopIteration:
return return

View File

@ -13,7 +13,7 @@ from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.context import Client, Connection, Context, Server from mitmproxy.proxy2.context import Client, Connection, Context, Server
from mitmproxy.proxy2.layer import Layer, NextLayer from mitmproxy.proxy2.layer import Layer, NextLayer
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS
from mitmproxy.proxy2.utils import expect from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human from mitmproxy.utils import human
@ -676,6 +676,9 @@ class HTTPLayer(Layer):
def make_http_connection(self, connection: Server) -> None: def make_http_connection(self, connection: Server) -> None:
if connection.tls and not connection.tls_established: if connection.tls and not connection.tls_established:
connection.alpn_offers = list(HTTP_ALPNS)
if not self.context.options.http2:
connection.alpn_offers.remove(b"h2")
new_command = EstablishServerTLS(connection) new_command = EstablishServerTLS(connection)
new_command.blocking = object() new_command.blocking = object()
yield new_command yield new_command

View File

@ -1,12 +1,10 @@
import os
import struct import struct
import time
from typing import Any, Dict, Generator, Iterator, Optional, Tuple from typing import Any, Dict, Generator, Iterator, Optional, Tuple
from OpenSSL import SSL from OpenSSL import SSL
from mitmproxy.certs import CertStore from mitmproxy.net import tls as net_tls
from mitmproxy.net.tls import ClientHello
from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS
from mitmproxy.proxy2 import commands, events, layer from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2 import context from mitmproxy.proxy2 import context
from mitmproxy.proxy2.utils import expect from mitmproxy.proxy2.utils import expect
@ -69,7 +67,7 @@ def get_client_hello(data: bytes) -> Optional[bytes]:
return None return None
def parse_client_hello(data: bytes) -> Optional[ClientHello]: def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
""" """
Check if the supplied bytes contain a full ClientHello message, Check if the supplied bytes contain a full ClientHello message,
and if so, parse it. and if so, parse it.
@ -84,10 +82,13 @@ def parse_client_hello(data: bytes) -> Optional[ClientHello]:
# Check if ClientHello is complete # Check if ClientHello is complete
client_hello = get_client_hello(data) client_hello = get_client_hello(data)
if client_hello: if client_hello:
return ClientHello(client_hello[4:]) return net_tls.ClientHello(client_hello[4:])
return None return None
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
class EstablishServerTLS(commands.ConnectionCommand): class EstablishServerTLS(commands.ConnectionCommand):
connection: context.Server connection: context.Server
blocking = True blocking = True
@ -99,9 +100,17 @@ class EstablishServerTLSReply(events.CommandReply):
"""error message""" """error message"""
class TlsStart:
def __init__(self, conn: context.Connection, context: context.Context) -> None:
self.conn = conn
self.context = context
self.ssl_conn = None
class _TLSLayer(layer.Layer): class _TLSLayer(layer.Layer):
tls: Dict[context.Connection, SSL.Connection] tls: Dict[context.Connection, SSL.Connection]
child_layer: layer.Layer child_layer: layer.Layer
ssl_context: Optional[SSL.Context] = None
def __init__(self, context: context.Context): def __init__(self, context: context.Context):
super().__init__(context) super().__init__(context)
@ -140,15 +149,18 @@ class _TLSLayer(layer.Layer):
except SSL.WantReadError: except SSL.WantReadError:
yield from self.tls_interact(conn) yield from self.tls_interact(conn)
return False, None return False, None
except SSL.ZeroReturnError as e: except SSL.Error as e:
return False, repr(e) return False, repr(e)
else: else:
conn.tls_established = True conn.tls_established = True
conn.sni = self.tls[conn].get_servername()
conn.alpn = self.tls[conn].get_alpn_proto_negotiated() conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
conn.cipher_list = self.tls[conn].get_cipher_list()
conn.tls_version = self.tls[conn].get_protocol_version_name()
conn.timestamp_tls_setup = time.time()
yield commands.Log(f"TLS established: {conn}") yield commands.Log(f"TLS established: {conn}")
yield from self.receive(conn, b"") yield from self.receive(conn, b"")
# TODO: Set all other connection attributes here # TODO: Set all other connection attributes here
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
return True, None return True, None
def receive(self, conn: context.Connection, data: bytes): def receive(self, conn: context.Connection, data: bytes):
@ -213,8 +225,8 @@ class ServerTLSLayer(_TLSLayer):
self.command_to_reply_to = {} self.command_to_reply_to = {}
self.child_layer = layer.NextLayer(self.context) self.child_layer = layer.NextLayer(self.context)
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[ def negotiate(self, conn: context.Connection, data: bytes) \
commands.Command, Any, Tuple[bool, Optional[str]]]: -> Generator[commands.Command, Any, Tuple[bool, Optional[str]]]:
done, err = yield from super().negotiate(conn, data) done, err = yield from super().negotiate(conn, data)
if done or err: if done or err:
cmd = self.command_to_reply_to.pop(conn) cmd = self.command_to_reply_to.pop(conn)
@ -232,19 +244,11 @@ class ServerTLSLayer(_TLSLayer):
def start_server_tls(self, conn: context.Server): def start_server_tls(self, conn: context.Server):
assert conn not in self.tls assert conn not in self.tls
assert conn.connected assert conn.connected
conn.tls = True
ssl_context = SSL.Context(SSL.SSLv23_METHOD) tls_start = TlsStart(conn, self.context)
if conn.alpn_offers: yield commands.Hook("tls_start", tls_start)
ssl_context.set_alpn_protos(conn.alpn_offers) self.tls[conn] = tls_start.ssl_conn
self.tls[conn] = SSL.Connection(ssl_context)
if conn.sni:
if conn.sni is True:
if self.context.client.sni:
conn.sni = self.context.client.sni
else:
conn.sni = conn.address[0].encode()
self.tls[conn].set_tlsext_host_name(conn.sni)
self.tls[conn].set_connect_state() self.tls[conn].set_connect_state()
yield from self.negotiate(conn, b"") yield from self.negotiate(conn, b"")
@ -274,6 +278,7 @@ class ClientTLSLayer(_TLSLayer):
super().__init__(context) super().__init__(context)
self.recv_buffer = bytearray() self.recv_buffer = bytearray()
self.child_layer = layer.NextLayer(self.context) self.child_layer = layer.NextLayer(self.context)
self._handle_event = self.state_start
@expect(events.Start) @expect(events.Start)
def state_start(self, _) -> commands.TCommandGenerator: def state_start(self, _) -> commands.TCommandGenerator:
@ -281,9 +286,6 @@ class ClientTLSLayer(_TLSLayer):
self._handle_event = self.state_wait_for_clienthello self._handle_event = self.state_wait_for_clienthello
yield from () yield from ()
_handle_event = state_start
@expect(events.DataReceived, events.ConnectionClosed)
def state_wait_for_clienthello(self, event: events.Event): def state_wait_for_clienthello(self, event: events.Event):
client = self.context.client client = self.context.client
if isinstance(event, events.DataReceived) and event.connection == client: if isinstance(event, events.DataReceived) and event.connection == client:
@ -296,8 +298,7 @@ class ClientTLSLayer(_TLSLayer):
if client_hello: if client_hello:
yield commands.Log(f"Client Hello: {client_hello}") yield commands.Log(f"Client Hello: {client_hello}")
# TODO: Don't do double conversion client.sni = client_hello.sni
client.sni = client_hello.sni.encode("idna")
client.alpn_offers = client_hello.alpn_protocols client.alpn_offers = client_hello.alpn_protocols
client_tls_requires_server_connection = ( client_tls_requires_server_connection = (
@ -322,8 +323,10 @@ class ClientTLSLayer(_TLSLayer):
# In any case, we now have enough information to start server TLS if needed. # In any case, we now have enough information to start server TLS if needed.
yield from self.event_to_child(events.Start()) yield from self.event_to_child(events.Start())
elif isinstance(event, events.ConnectionClosed) and event.connection == client:
self.recv_buffer.clear()
else: else:
raise NotImplementedError(event) # TODO yield from self.event_to_child(event)
def start_server_tls(self): def start_server_tls(self):
""" """
@ -339,11 +342,6 @@ class ClientTLSLayer(_TLSLayer):
) )
return err return err
server.alpn_offers = [
x for x in self.context.client.alpn_offers
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
]
err = yield EstablishServerTLS(server) err = yield EstablishServerTLS(server)
if err: if err:
yield commands.Log( yield commands.Log(
@ -352,36 +350,10 @@ class ClientTLSLayer(_TLSLayer):
return err return err
def start_client_tls(self) -> commands.TCommandGenerator: def start_client_tls(self) -> commands.TCommandGenerator:
# FIXME: Do this properly. Also adjust error message in negotiate()
client = self.context.client client = self.context.client
server = self.context.server tls_start = TlsStart(client, self.context)
context = SSL.Context(SSL.SSLv23_METHOD) yield commands.Hook("tls_start", tls_start)
cert, privkey, cert_chain = CertStore.from_store( self.tls[client] = tls_start.ssl_conn
os.path.expanduser("~/.mitmproxy"), "mitmproxy",
self.context.options.key_size
).get_cert(client.sni, (client.sni,))
context.use_privatekey(privkey)
context.use_certificate(cert.x509)
context.set_cipher_list(DEFAULT_CLIENT_CIPHERS)
def alpn_select_callback(conn_, options):
if server.alpn in options:
return server.alpn
elif b"h2" in options:
return b"h2"
elif b"http/1.1" in options:
return b"http/1.1"
elif b"http/1.0" in options:
return b"http/1.0"
elif b"http/0.9" in options:
return b"http/0.9"
else:
# FIXME: We MUST return something here. At this point we are at loss.
return options[0]
context.set_alpn_select_callback(alpn_select_callback)
self.tls[client] = SSL.Connection(context)
self.tls[client].set_accept_state() self.tls[client].set_accept_state()
yield from self.negotiate(client, bytes(self.recv_buffer)) yield from self.negotiate(client, bytes(self.recv_buffer))
@ -390,11 +362,16 @@ class ClientTLSLayer(_TLSLayer):
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]: def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
done, err = yield from super().negotiate(conn, data) done, err = yield from super().negotiate(conn, data)
if err: if err:
if self.context.client.sni:
# TODO: Also use other sources than SNI
dest = " for " + self.context.client.sni.decode("idna")
else:
dest = ""
yield commands.Log( yield commands.Log(
f"Client TLS Handshake failed. " f"Client TLS Handshake failed. "
f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).", f"The client may not trust the proxy's certificate{dest} ({err}).",
level="warn" level="warn"
# TODO: Also use other sources than SNI
) )
yield commands.CloseConnection(self.context.client) yield commands.CloseConnection(self.context.client)
return done return done

View File

@ -24,7 +24,7 @@ def test_open_connection(tctx):
def test_open_connection_err(tctx): def test_open_connection_err(tctx):
f = Placeholder() f = Placeholder()
assert ( assert (
playbook(TCPLayer(tctx), hooks=True) playbook(TCPLayer(tctx))
<< Hook("tcp_start", f) << Hook("tcp_start", f)
>> reply() >> reply()
<< OpenConnection(tctx.server) << OpenConnection(tctx.server)
@ -40,7 +40,7 @@ def test_simple(tctx):
f = Placeholder() f = Placeholder()
assert ( assert (
playbook(TCPLayer(tctx), hooks=True) playbook(TCPLayer(tctx))
<< Hook("tcp_start", f) << Hook("tcp_start", f)
>> reply() >> reply()
<< OpenConnection(tctx.server) << OpenConnection(tctx.server)
@ -71,7 +71,7 @@ def test_receive_data_before_server_connected(tctx):
will still be forwarded. will still be forwarded.
""" """
assert ( assert (
playbook(TCPLayer(tctx)) playbook(TCPLayer(tctx), hooks=False)
<< OpenConnection(tctx.server) << OpenConnection(tctx.server)
>> DataReceived(tctx.client, b"hello!") >> DataReceived(tctx.client, b"hello!")
>> reply(None, to=-2) >> reply(None, to=-2)

View File

@ -3,11 +3,15 @@ import ssl
import typing import typing
import pytest import pytest
from OpenSSL import SSL
from mitmproxy.proxy2 import context, events, commands from mitmproxy.proxy2 import commands, context, events
from mitmproxy.proxy2.layers import tls from mitmproxy.proxy2.layers import tls
from mitmproxy.utils import data
from test.mitmproxy.proxy2 import tutils from test.mitmproxy.proxy2 import tutils
tlsdata = data.Data(__name__)
def test_is_tls_handshake_record(): def test_is_tls_handshake_record():
assert tls.is_tls_handshake_record(bytes.fromhex("160300")) assert tls.is_tls_handshake_record(bytes.fromhex("160300"))
@ -33,11 +37,11 @@ def test_record_contents():
def test_record_contents_err(): def test_record_contents_err():
with pytest.raises(ValueError, msg="Expected TLS record"): with pytest.raises(ValueError, match="Expected TLS record"):
next(tls.handshake_record_contents(b"GET /error")) next(tls.handshake_record_contents(b"GET /error"))
empty_record = bytes.fromhex("1603010000") empty_record = bytes.fromhex("1603010000")
with pytest.raises(ValueError, msg="Record must not be empty"): with pytest.raises(ValueError, match="Record must not be empty"):
next(tls.handshake_record_contents(empty_record)) next(tls.handshake_record_contents(empty_record))
@ -65,7 +69,8 @@ def test_get_client_hello():
class SSLTest: class SSLTest:
"""Helper container for Python's builtin SSL object.""" """Helper container for Python's builtin SSL object."""
def __init__(self, server_side=False, alpn=None): def __init__(self, server_side: bool = False, alpn: typing.List[bytes] = None,
sni: typing.Optional[bytes] = b"example.com"):
self.inc = ssl.MemoryBIO() self.inc = ssl.MemoryBIO()
self.out = ssl.MemoryBIO() self.out = ssl.MemoryBIO()
self.ctx = ssl.SSLContext() self.ctx = ssl.SSLContext()
@ -77,83 +82,78 @@ class SSLTest:
self.obj = self.ctx.wrap_bio( self.obj = self.ctx.wrap_bio(
self.inc, self.inc,
self.out, self.out,
server_hostname=None if server_side else "example.com", server_hostname=None if server_side else sni,
server_side=server_side, server_side=server_side,
) )
def _test_tls_client_server( def _test_echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
tctx: context.Context,
alpn: typing.Optional[str]
) -> typing.Tuple[tutils.playbook[tls.ClientTLSLayer], SSLTest]:
layer = tls.ClientTLSLayer(tctx)
playbook = tutils.playbook(layer)
tctx.server.tls = True
tctx.server.address = ("example.com", 443)
tssl_client = SSLTest(alpn=alpn)
# Handshake
assert (
playbook
<< None
)
with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake()
client_hello = tssl_client.out.read()
assert (
playbook
>> events.DataReceived(tctx.client, client_hello[:42])
<< None
)
# Still waiting...
# Finish sending ClientHello
playbook >> events.DataReceived(tctx.client, client_hello[42:])
return playbook, tssl_client
def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
tssl.obj.write(b"Hello World") tssl.obj.write(b"Hello World")
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(conn, tssl.out.read()) >> events.DataReceived(conn, tssl.out.read())
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(tutils.EchoLayer)
<< commands.SendData(conn, data) << commands.SendData(conn, data)
) )
tssl.inc.write(data()) tssl.inc.write(data())
assert tssl.obj.read() == b"hello world" assert tssl.obj.read() == b"hello world"
class TlsEchoLayer(tutils.EchoLayer):
err: typing.Optional[str] = None
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls":
# noinspection PyTypeChecker
self.err = yield tls.EstablishServerTLS(self.context.server)
else:
yield from super()._handle_event(event)
def interact(playbook: tutils.playbook, conn: context.Connection, tssl: SSLTest):
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(conn, tssl.out.read())
<< commands.SendData(conn, data)
)
tssl.inc.write(data())
def reply_tls_start(*args, **kwargs) -> tutils.reply:
"""
Helper function to simplify the syntax for tls_start hooks.
"""
def make_conn(hook: commands.Hook) -> None:
tls_start = hook.data
assert isinstance(tls_start, tls.TlsStart)
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
if tls_start.conn == tls_start.context.client:
ssl_context.use_privatekey_file(
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.key")
)
ssl_context.use_certificate_chain_file(
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.crt")
)
tls_start.ssl_conn = SSL.Connection(ssl_context)
return tutils.reply(*args, side_effect=make_conn, **kwargs)
class TestServerTLS: class TestServerTLS:
def test_no_tls(self, tctx: context.Context): def test_no_tls(self, tctx: context.Context):
"""Test TLS layer without TLS""" """Test TLS layer without TLS"""
layer = tls.ServerTLSLayer(tctx) layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer) layer.child_layer = TlsEchoLayer(tctx)
# Handshake # Handshake
assert ( assert (
playbook tutils.playbook(layer)
>> events.DataReceived(tctx.client, b"Hello World") >> events.DataReceived(tctx.client, b"Hello World")
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(tutils.EchoLayer)
<< commands.SendData(tctx.client, b"hello world") << commands.SendData(tctx.client, b"hello world")
) >> events.DataReceived(tctx.server, b"Foo")
<< commands.SendData(tctx.server, b"foo")
def test_no_connection(self, tctx):
"""
The server TLS layer is initiated, but there is no active connection yet, so nothing
should be done.
"""
layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer)
tctx.server.tls = True
# We did not have a server connection before, so let's do nothing.
assert (
playbook
<< None
) )
def test_simple(self, tctx): def test_simple(self, tctx):
@ -161,7 +161,6 @@ class TestServerTLS:
playbook = tutils.playbook(layer) playbook = tutils.playbook(layer)
tctx.server.connected = True tctx.server.connected = True
tctx.server.address = ("example.com", 443) tctx.server.address = ("example.com", 443)
tctx.server.tls = True
tssl = SSLTest(server_side=True) tssl = SSLTest(server_side=True)
@ -169,6 +168,11 @@ class TestServerTLS:
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, b"establish-server-tls")
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(TlsEchoLayer)
<< commands.Hook("tls_start", tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.server, data) << commands.SendData(tctx.server, data)
) )
@ -176,13 +180,7 @@ class TestServerTLS:
tssl.inc.write(data()) tssl.inc.write(data())
with pytest.raises(ssl.SSLWantReadError): with pytest.raises(ssl.SSLWantReadError):
tssl.obj.do_handshake() tssl.obj.do_handshake()
data = tutils.Placeholder() interact(playbook, tctx.server, tssl)
assert (
playbook
>> events.DataReceived(tctx.server, tssl.out.read())
<< commands.SendData(tctx.server, data)
)
tssl.inc.write(data())
# finish server handshake # finish server handshake
tssl.obj.do_handshake() tssl.obj.do_handshake()
@ -193,134 +191,141 @@ class TestServerTLS:
) )
assert tctx.server.tls_established assert tctx.server.tls_established
assert tctx.server.sni == b"example.com"
# Echo # Echo
echo(playbook, tssl, tctx.server) _test_echo(playbook, tssl, tctx.server)
def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer]:
# This is a bit contrived as the client layer expects a server layer as parent.
# We also set child layers manually to avoid NextLayer noise.
server_layer = tls.ServerTLSLayer(tctx)
client_layer = tls.ClientTLSLayer(tctx)
server_layer.child_layer = client_layer
client_layer.child_layer = TlsEchoLayer(tctx)
playbook = tutils.playbook(server_layer)
return playbook, client_layer
def _test_tls_client_server(
tctx: context.Context,
sni: typing.Optional[bytes]
) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer, SSLTest]:
playbook, client_layer = _make_client_tls_layer(tctx)
tctx.server.tls = True
tctx.server.address = ("example.com", 443)
tssl_client = SSLTest(sni=sni)
# Send ClientHello
with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake()
return playbook, client_layer, tssl_client
class TestClientTLS: class TestClientTLS:
def test_simple(self, tctx: context.Context): def test_client_only(self, tctx: context.Context):
"""Test TLS with client only""" """Test TLS with client only"""
layer = tls.ClientTLSLayer(tctx) playbook, client_layer = _make_client_tls_layer(tctx)
playbook = tutils.playbook(layer)
tssl = SSLTest() tssl = SSLTest()
assert not tctx.client.tls_established
# Handshake # Start Handshake, send ClientHello and ServerHello
assert playbook with pytest.raises(ssl.SSLWantReadError):
assert layer._handle_event == layer.state_wait_for_clienthello tssl.obj.do_handshake()
def interact():
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, tssl.out.read()) >> events.DataReceived(tctx.client, tssl.out.read())
<< commands.Hook("tls_start", tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.client, data) << commands.SendData(tctx.client, data)
) )
tssl.inc.write(data()) tssl.inc.write(data())
try:
tssl.obj.do_handshake() tssl.obj.do_handshake()
except ssl.SSLWantReadError:
return False
else:
return True
# receive ClientHello, send ServerHello
with pytest.raises(ssl.SSLWantReadError):
tssl.obj.do_handshake()
assert not interact()
# Finish Handshake # Finish Handshake
assert interact() interact(playbook, tctx.client, tssl)
tssl.obj.do_handshake()
assert layer._handle_event == layer.state_process assert tssl.obj.getpeercert(True)
assert tctx.client.tls_established
# Echo # Echo
echo(playbook, tssl, tctx.client) _test_echo(playbook, tssl, tctx.client)
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.server, b"Hello") >> events.DataReceived(tctx.server, b"Plaintext")
<< commands.SendData(tctx.server, b"hello") << commands.SendData(tctx.server, b"plaintext")
) )
def test_no_server_conn_required(self, tctx): def test_server_not_required(self, tctx):
""" """
Here we test the scenario where a server connection is _not_ required Here we test the scenario where a server connection is _not_ required
to establish TLS with the client. After determining this when parsing the ClientHello, to establish TLS with the client. After determining this when parsing the ClientHello,
we only establish a connection with the client. The server connection may ultimately we only establish a connection with the client. The server connection may ultimately
be established when OpenConnection is called. be established when OpenConnection is called.
""" """
playbook, _ = _test_tls_client_server(tctx, None) playbook, client_layer, tssl = _test_tls_client_server(tctx, sni=b"example.com")
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, tssl.out.read())
<< commands.Hook("tls_start", tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.client, data) << commands.SendData(tctx.client, data)
) )
assert data() tssl.inc.write(data())
assert playbook.layer._handle_event == playbook.layer.state_process tssl.obj.do_handshake()
interact(playbook, tctx.client, tssl)
assert tctx.client.tls_established
def test_alpn(self, tctx): def test_server_required(self, tctx):
""" """
Here we test the scenario where a server connection is required (e.g. because of ALPN negotation) Here we test the scenario where a server connection is required (because SNI is missing)
to establish TLS with the client. to establish TLS with the client.
""" """
tssl_server = SSLTest(server_side=True, alpn=["foo", "bar"]) tssl_server = SSLTest(server_side=True)
playbook, client_layer, tssl_client = _test_tls_client_server(tctx, sni=None)
playbook, tssl_client = _test_tls_client_server(tctx, ["qux", "foo"])
# We should now get instructed to open a server connection. # We should now get instructed to open a server connection.
assert (
playbook
<< commands.OpenConnection(tctx.server)
)
tctx.server.connected = True
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.OpenConnectionReply(-1, None) >> events.DataReceived(tctx.client, tssl_client.out.read())
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
<< commands.Hook("tls_start", tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.server, data) << commands.SendData(tctx.server, data)
) )
assert playbook.layer._handle_event == playbook.layer.state_wait_for_server_tls
assert playbook.layer.child_layer.tls[tctx.server]
# Establish TLS with the server... # Establish TLS with the server...
tssl_server.inc.write(data()) tssl_server.inc.write(data())
with pytest.raises(ssl.SSLWantReadError): with pytest.raises(ssl.SSLWantReadError):
tssl_server.obj.do_handshake() tssl_server.obj.do_handshake()
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.server, tssl_server.out.read()) >> events.DataReceived(tctx.server, tssl_server.out.read())
<< commands.SendData(tctx.server, data) << commands.SendData(tctx.server, data)
<< commands.Hook("tls_start", tutils.Placeholder())
) )
tssl_server.inc.write(data()) tssl_server.inc.write(data())
tssl_server.obj.do_handshake()
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.server, tssl_server.out.read())
<< commands.SendData(tctx.client, data)
)
assert playbook.layer._handle_event == playbook.layer.state_process
assert tctx.server.tls_established assert tctx.server.tls_established
# Server TLS is established, we can now reply to the client handshake... # Server TLS is established, we can now reply to the client handshake...
tssl_client.inc.write(data())
with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake()
data = tutils.Placeholder() data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, tssl_client.out.read()) >> reply_tls_start()
<< commands.SendData(tctx.client, data) << commands.SendData(tctx.client, data)
) )
tssl_client.inc.write(data()) tssl_client.inc.write(data())
tssl_client.obj.do_handshake() tssl_client.obj.do_handshake()
interact(playbook, tctx.client, tssl_client)
# Both handshakes completed! # Both handshakes completed!
assert tctx.client.tls_established assert tctx.client.tls_established
assert tctx.server.tls_established assert tctx.server.tls_established
_test_echo(playbook, tssl_server, tctx.server)
assert tssl_client.obj.selected_alpn_protocol() == "foo" _test_echo(playbook, tssl_client, tctx.client)
assert tssl_server.obj.selected_alpn_protocol() == "foo"

View File

@ -5,13 +5,13 @@ from test.mitmproxy.proxy2 import tutils
class TestNextLayer: class TestNextLayer:
def test_simple(self, tctx): def test_simple(self, tctx):
nl = layer.NextLayer(tctx) nl = layer.NextLayer(tctx)
playbook = tutils.playbook(nl) playbook = tutils.playbook(nl, hooks=True)
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, b"foo") >> events.DataReceived(tctx.client, b"foo")
<< commands.Hook("next_layer", nl) << commands.Hook("next_layer", nl)
>> events.HookReply(-1) >> tutils.reply()
>> events.DataReceived(tctx.client, b"bar") >> events.DataReceived(tctx.client, b"bar")
<< commands.Hook("next_layer", nl) << commands.Hook("next_layer", nl)
) )
@ -21,7 +21,7 @@ class TestNextLayer:
nl.layer = tutils.EchoLayer(tctx) nl.layer = tutils.EchoLayer(tctx)
assert ( assert (
playbook playbook
>> events.HookReply(-1) >> tutils.reply()
<< commands.SendData(tctx.client, b"foo") << commands.SendData(tctx.client, b"foo")
<< commands.SendData(tctx.client, b"bar") << commands.SendData(tctx.client, b"bar")
) )
@ -45,7 +45,7 @@ class TestNextLayer:
assert ( assert (
playbook playbook
>> events.HookReply(-2) >> tutils.reply(to=-2)
<< commands.SendData(tctx.client, b"foo") << commands.SendData(tctx.client, b"foo")
<< commands.SendData(tctx.client, b"bar") << commands.SendData(tctx.client, b"bar")
) )
@ -63,7 +63,7 @@ class TestNextLayer:
handle = nl.handle_event handle = nl.handle_event
assert ( assert (
playbook playbook
>> events.HookReply(-1) >> tutils.reply()
<< commands.SendData(tctx.client, b"foo") << commands.SendData(tctx.client, b"foo")
) )
sd, = handle(events.DataReceived(tctx.client, b"bar")) sd, = handle(events.DataReceived(tctx.client, b"bar"))

View File

@ -38,7 +38,7 @@ class TLayer(Layer):
@pytest.fixture @pytest.fixture
def tplaybook(tctx): def tplaybook(tctx):
return tutils.playbook(TLayer(tctx), []) return tutils.playbook(TLayer(tctx), expected=[])
def test_simple(tplaybook): def test_simple(tplaybook):
@ -158,7 +158,7 @@ def test_command_reply(tplaybook):
tplaybook tplaybook
>> TEvent() >> TEvent()
<< TCommand() << TCommand()
>> TCommandReply(-1, 42) >> tutils.reply(42)
) )
assert tplaybook.actual[1] == tplaybook.actual[2].command assert tplaybook.actual[1] == tplaybook.actual[2].command

View File

@ -2,6 +2,7 @@ import collections.abc
import copy import copy
import difflib import difflib
import itertools import itertools
import sys
import typing import typing
from mitmproxy.proxy2 import commands, context from mitmproxy.proxy2 import commands, context
@ -101,7 +102,7 @@ class playbook:
def __init__( def __init__(
self, self,
layer: Layer, layer: Layer,
hooks: bool = False, hooks: bool = True,
logs: bool = False, logs: bool = False,
expected: typing.Optional[TPlaybook] = None, expected: typing.Optional[TPlaybook] = None,
): ):
@ -196,13 +197,13 @@ class playbook:
class reply(events.Event): class reply(events.Event):
args: typing.Tuple[typing.Any, ...] args: typing.Tuple[typing.Any, ...]
to: typing.Union[commands.Command, int] to: typing.Union[commands.Command, int]
side_effect: typing.Callable[[commands.Command], typing.Any] side_effect: typing.Callable[[typing.Any], typing.Any]
def __init__( def __init__(
self, self,
*args, *args,
to: typing.Union[commands.Command, int] = -1, to: typing.Union[commands.Command, int] = -1,
side_effect: typing.Callable[[commands.Command], typing.Any] = lambda cmd: None side_effect: typing.Callable[[typing.Any], None] = lambda x: None
): ):
"""Utility method to reply to the latest hook in playbooks.""" """Utility method to reply to the latest hook in playbooks."""
self.args = args self.args = args
@ -226,6 +227,7 @@ class reply(events.Event):
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual) actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}") raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}")
assert isinstance(self.to, commands.Command)
self.side_effect(self.to) self.side_effect(self.to)
reply_cls = command_reply_subclasses[type(self.to)] reply_cls = command_reply_subclasses[type(self.to)]
try: try:
@ -272,14 +274,16 @@ def Placeholder() -> typing.Any:
class EchoLayer(Layer): class EchoLayer(Layer):
"""Echo layer that sends all data back to the client in lowercase.""" """Echo layer that sends all data back to the client in lowercase."""
def _handle_event(self, event: events.Event): def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.DataReceived): if isinstance(event, events.DataReceived):
yield commands.SendData(event.connection, event.data.lower()) yield commands.SendData(event.connection, event.data.lower())
def next_layer( def next_layer(
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]] layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]],
) -> events.HookReply: *args,
**kwargs
) -> reply:
""" """
Helper function to simplify the syntax for next_layer events from this: Helper function to simplify the syntax for next_layer events from this:
@ -294,21 +298,10 @@ def next_layer(
<< commands.Hook("next_layer", next_layer) << commands.Hook("next_layer", next_layer)
>> tutils.next_layer(next_layer, tutils.EchoLayer) >> tutils.next_layer(next_layer, tutils.EchoLayer)
>> tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context)
""" """
raise RuntimeError("Does tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context) work?")
if isinstance(layer, type):
def make_layer(ctx: context.Context) -> Layer:
return layer(ctx)
else:
make_layer = layer
def set_layer(playbook: playbook) -> None: def set_layer(hook: commands.Hook) -> None:
last_command = playbook.actual[-1] assert isinstance(hook.data, NextLayer)
assert isinstance(last_command, commands.Hook) hook.data.layer = layer(hook.data.context)
assert isinstance(last_command.data, NextLayer)
last_command.data.layer = make_layer(last_command.data.context)
reply = events.HookReply(-1) return reply(*args, side_effect=set_layer, **kwargs)
reply._playbook_eval = set_layer
return reply