[sans-io] tls layer++

This commit is contained in:
Maximilian Hils 2019-11-11 18:32:01 +01:00
parent 0c04638d8d
commit 1c80dfe17f
10 changed files with 371 additions and 254 deletions

View File

@ -0,0 +1,138 @@
import os
from typing import Optional, Tuple
from OpenSSL import SSL, crypto
from mitmproxy import certs, ctx, exceptions
from mitmproxy.net import tls as net_tls
from mitmproxy.options import CONF_BASENAME
from mitmproxy.proxy.protocol.tls import CIPHER_ID_NAME_MAP, DEFAULT_CLIENT_CIPHERS
from mitmproxy.proxy2 import context
from mitmproxy.proxy2.layers import tls
def alpn_select_callback(conn: SSL.Connection, options):
server_alpn = conn.get_app_data()["server_alpn"]
if server_alpn and server_alpn in options:
return server_alpn
for alpn in tls.HTTP_ALPNS:
if alpn in options:
return alpn
else:
# FIXME: pyOpenSSL requires that an ALPN is negotiated, we can't return SSL_TLSEXT_ERR_NOACK.
return options[0]
class TlsConfig:
certstore: certs.CertStore
# TODO: We should re-use SSL.Context options here, if only for TLS session resumption.
# This may require patches to pyOpenSSL, as some functionality is only exposed on contexts.
def get_cert(self, context: context.Context) -> Tuple[certs.Cert, SSL.PKey, str]:
return self.certstore.get_cert(
context.client.sni, [context.client.sni]
)
def tls_start(self, tls_start: tls.TlsStart):
if tls_start.conn == tls_start.context.client:
self.create_client_proxy_ssl_conn(tls_start)
else:
self.create_proxy_server_ssl_conn(tls_start)
def create_client_proxy_ssl_conn(self, tls_start: tls.TlsStart) -> None:
tls_method, tls_options = net_tls.VERSION_CHOICES[ctx.options.ssl_version_client]
cert, key, chain_file = self.get_cert(tls_start.context)
if ctx.options.add_upstream_certs_to_client_chain:
raise NotImplementedError()
else:
extra_chain_certs = None
ssl_ctx = net_tls.create_server_context(
cert=cert,
key=key,
method=tls_method,
options=tls_options,
cipher_list=ctx.options.ciphers_client or DEFAULT_CLIENT_CIPHERS,
dhparams=self.certstore.dhparams,
chain_file=chain_file,
alpn_select_callback=alpn_select_callback,
extra_chain_certs=extra_chain_certs,
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
tls_start.ssl_conn.set_app_data({
"server_alpn": tls_start.context.server.alpn
})
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStart) -> None:
client = tls_start.context.client
server: context.Server = tls_start.conn
if server.sni is True:
server.sni = client.sni or server.address[0].encode()
if not server.alpn_offers:
if client.alpn:
server.alpn_offers = [client.alpn]
elif client.alpn_offers:
server.alpn_offers = client.alpn_offers
# We pass through the list of ciphers send by the client, because some HTTP/2 servers
# will select a non-HTTP/2 compatible cipher from our default list and then hang up
# because it's incompatible with h2.
if not server.cipher_list:
if ctx.options.ciphers_server:
server.cipher_list = ctx.options.ciphers_server.split(":")
elif client.cipher_list:
server.cipher_list = [
x for x in client.cipher_list
if x in CIPHER_ID_NAME_MAP
]
args = net_tls.client_arguments_from_options(ctx.options)
client_certs = args.pop("client_certs")
client_cert: Optional[str] = None
if client_certs:
client_certs = os.path.expanduser(client_certs)
if os.path.isfile(client_certs):
client_cert = client_certs
else:
server_name: str = (server.sni or server.address[0].encode("idna")).decode()
path = os.path.join(client_certs, f"{server_name}.pem")
if os.path.exists(path):
client_cert = path
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None
ssl_ctx = net_tls.create_client_context(
cert=client_cert,
sni=server.sni.decode("idna"), # FIXME: Should pass-through here.
alpn_protos=server.alpn_offers,
**args
)
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
def configure(self, updated):
if not any(x in updated for x in ["confdir", "certs"]):
return
certstore_path = os.path.expanduser(ctx.options.confdir)
if not os.path.exists(os.path.dirname(certstore_path)):
raise exceptions.OptionsError(
f"Certificate Authority parent directory does not exist: {os.path.dirname(certstore_path)}")
self.certstore = certs.CertStore.from_store(
path=certstore_path,
basename=CONF_BASENAME,
key_size=ctx.options.key_size
)
for certspec in ctx.options.certs:
parts = certspec.split("=", 1)
if len(parts) == 1:
parts = ["*", parts[0]]
cert = os.path.expanduser(parts[1])
if not os.path.exists(cert):
raise exceptions.OptionsError(f"Certificate file does not exist: {cert}")
try:
self.certstore.add_cert_file(parts[0], cert)
except crypto.Error as e:
raise exceptions.OptionsError(f"Invalid certificate format: {cert}") from e

View File

@ -21,8 +21,12 @@ class Connection:
tls_established: bool = False
alpn: Optional[bytes] = None
alpn_offers: Sequence[bytes] = ()
cipher_list: Sequence[bytes] = ()
tls_version: Optional[str] = None
sni: Union[bytes, bool, None]
timestamp_tls_setup: Optional[float] = None
@property
def connected(self):
return self.state is ConnectionState.OPEN

View File

@ -67,7 +67,7 @@ class Layer:
@abstractmethod
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
"""Handle a proxy server event"""
yield from ()
yield from () # pragma: no cover
def handle_event(self, event: events.Event) -> commands.TCommandGenerator:
if self._paused:
@ -95,9 +95,6 @@ class Layer:
processing any other commands.
"""
try:
if isinstance(send, Exception):
command = command_generator.throw(type(send), send)
else:
command = command_generator.send(send)
except StopIteration:
return

View File

@ -13,7 +13,7 @@ from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.context import Client, Connection, Context, Server
from mitmproxy.proxy2.layer import Layer, NextLayer
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply
from mitmproxy.proxy2.layers.tls import EstablishServerTLS, EstablishServerTLSReply, HTTP_ALPNS
from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human
@ -676,6 +676,9 @@ class HTTPLayer(Layer):
def make_http_connection(self, connection: Server) -> None:
if connection.tls and not connection.tls_established:
connection.alpn_offers = list(HTTP_ALPNS)
if not self.context.options.http2:
connection.alpn_offers.remove(b"h2")
new_command = EstablishServerTLS(connection)
new_command.blocking = object()
yield new_command

View File

@ -1,12 +1,10 @@
import os
import struct
import time
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
from OpenSSL import SSL
from mitmproxy.certs import CertStore
from mitmproxy.net.tls import ClientHello
from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS
from mitmproxy.net import tls as net_tls
from mitmproxy.proxy2 import commands, events, layer
from mitmproxy.proxy2 import context
from mitmproxy.proxy2.utils import expect
@ -69,7 +67,7 @@ def get_client_hello(data: bytes) -> Optional[bytes]:
return None
def parse_client_hello(data: bytes) -> Optional[ClientHello]:
def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
"""
Check if the supplied bytes contain a full ClientHello message,
and if so, parse it.
@ -84,10 +82,13 @@ def parse_client_hello(data: bytes) -> Optional[ClientHello]:
# Check if ClientHello is complete
client_hello = get_client_hello(data)
if client_hello:
return ClientHello(client_hello[4:])
return net_tls.ClientHello(client_hello[4:])
return None
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
class EstablishServerTLS(commands.ConnectionCommand):
connection: context.Server
blocking = True
@ -99,9 +100,17 @@ class EstablishServerTLSReply(events.CommandReply):
"""error message"""
class TlsStart:
def __init__(self, conn: context.Connection, context: context.Context) -> None:
self.conn = conn
self.context = context
self.ssl_conn = None
class _TLSLayer(layer.Layer):
tls: Dict[context.Connection, SSL.Connection]
child_layer: layer.Layer
ssl_context: Optional[SSL.Context] = None
def __init__(self, context: context.Context):
super().__init__(context)
@ -140,15 +149,18 @@ class _TLSLayer(layer.Layer):
except SSL.WantReadError:
yield from self.tls_interact(conn)
return False, None
except SSL.ZeroReturnError as e:
except SSL.Error as e:
return False, repr(e)
else:
conn.tls_established = True
conn.sni = self.tls[conn].get_servername()
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
conn.cipher_list = self.tls[conn].get_cipher_list()
conn.tls_version = self.tls[conn].get_protocol_version_name()
conn.timestamp_tls_setup = time.time()
yield commands.Log(f"TLS established: {conn}")
yield from self.receive(conn, b"")
# TODO: Set all other connection attributes here
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
return True, None
def receive(self, conn: context.Connection, data: bytes):
@ -213,8 +225,8 @@ class ServerTLSLayer(_TLSLayer):
self.command_to_reply_to = {}
self.child_layer = layer.NextLayer(self.context)
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
commands.Command, Any, Tuple[bool, Optional[str]]]:
def negotiate(self, conn: context.Connection, data: bytes) \
-> Generator[commands.Command, Any, Tuple[bool, Optional[str]]]:
done, err = yield from super().negotiate(conn, data)
if done or err:
cmd = self.command_to_reply_to.pop(conn)
@ -232,19 +244,11 @@ class ServerTLSLayer(_TLSLayer):
def start_server_tls(self, conn: context.Server):
assert conn not in self.tls
assert conn.connected
conn.tls = True
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
if conn.alpn_offers:
ssl_context.set_alpn_protos(conn.alpn_offers)
self.tls[conn] = SSL.Connection(ssl_context)
if conn.sni:
if conn.sni is True:
if self.context.client.sni:
conn.sni = self.context.client.sni
else:
conn.sni = conn.address[0].encode()
self.tls[conn].set_tlsext_host_name(conn.sni)
tls_start = TlsStart(conn, self.context)
yield commands.Hook("tls_start", tls_start)
self.tls[conn] = tls_start.ssl_conn
self.tls[conn].set_connect_state()
yield from self.negotiate(conn, b"")
@ -274,6 +278,7 @@ class ClientTLSLayer(_TLSLayer):
super().__init__(context)
self.recv_buffer = bytearray()
self.child_layer = layer.NextLayer(self.context)
self._handle_event = self.state_start
@expect(events.Start)
def state_start(self, _) -> commands.TCommandGenerator:
@ -281,9 +286,6 @@ class ClientTLSLayer(_TLSLayer):
self._handle_event = self.state_wait_for_clienthello
yield from ()
_handle_event = state_start
@expect(events.DataReceived, events.ConnectionClosed)
def state_wait_for_clienthello(self, event: events.Event):
client = self.context.client
if isinstance(event, events.DataReceived) and event.connection == client:
@ -296,8 +298,7 @@ class ClientTLSLayer(_TLSLayer):
if client_hello:
yield commands.Log(f"Client Hello: {client_hello}")
# TODO: Don't do double conversion
client.sni = client_hello.sni.encode("idna")
client.sni = client_hello.sni
client.alpn_offers = client_hello.alpn_protocols
client_tls_requires_server_connection = (
@ -322,8 +323,10 @@ class ClientTLSLayer(_TLSLayer):
# In any case, we now have enough information to start server TLS if needed.
yield from self.event_to_child(events.Start())
elif isinstance(event, events.ConnectionClosed) and event.connection == client:
self.recv_buffer.clear()
else:
raise NotImplementedError(event) # TODO
yield from self.event_to_child(event)
def start_server_tls(self):
"""
@ -339,11 +342,6 @@ class ClientTLSLayer(_TLSLayer):
)
return err
server.alpn_offers = [
x for x in self.context.client.alpn_offers
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
]
err = yield EstablishServerTLS(server)
if err:
yield commands.Log(
@ -352,36 +350,10 @@ class ClientTLSLayer(_TLSLayer):
return err
def start_client_tls(self) -> commands.TCommandGenerator:
# FIXME: Do this properly. Also adjust error message in negotiate()
client = self.context.client
server = self.context.server
context = SSL.Context(SSL.SSLv23_METHOD)
cert, privkey, cert_chain = CertStore.from_store(
os.path.expanduser("~/.mitmproxy"), "mitmproxy",
self.context.options.key_size
).get_cert(client.sni, (client.sni,))
context.use_privatekey(privkey)
context.use_certificate(cert.x509)
context.set_cipher_list(DEFAULT_CLIENT_CIPHERS)
def alpn_select_callback(conn_, options):
if server.alpn in options:
return server.alpn
elif b"h2" in options:
return b"h2"
elif b"http/1.1" in options:
return b"http/1.1"
elif b"http/1.0" in options:
return b"http/1.0"
elif b"http/0.9" in options:
return b"http/0.9"
else:
# FIXME: We MUST return something here. At this point we are at loss.
return options[0]
context.set_alpn_select_callback(alpn_select_callback)
self.tls[client] = SSL.Connection(context)
tls_start = TlsStart(client, self.context)
yield commands.Hook("tls_start", tls_start)
self.tls[client] = tls_start.ssl_conn
self.tls[client].set_accept_state()
yield from self.negotiate(client, bytes(self.recv_buffer))
@ -390,11 +362,16 @@ class ClientTLSLayer(_TLSLayer):
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
done, err = yield from super().negotiate(conn, data)
if err:
if self.context.client.sni:
# TODO: Also use other sources than SNI
dest = " for " + self.context.client.sni.decode("idna")
else:
dest = ""
yield commands.Log(
f"Client TLS Handshake failed. "
f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).",
f"The client may not trust the proxy's certificate{dest} ({err}).",
level="warn"
# TODO: Also use other sources than SNI
)
yield commands.CloseConnection(self.context.client)
return done

View File

@ -24,7 +24,7 @@ def test_open_connection(tctx):
def test_open_connection_err(tctx):
f = Placeholder()
assert (
playbook(TCPLayer(tctx), hooks=True)
playbook(TCPLayer(tctx))
<< Hook("tcp_start", f)
>> reply()
<< OpenConnection(tctx.server)
@ -40,7 +40,7 @@ def test_simple(tctx):
f = Placeholder()
assert (
playbook(TCPLayer(tctx), hooks=True)
playbook(TCPLayer(tctx))
<< Hook("tcp_start", f)
>> reply()
<< OpenConnection(tctx.server)
@ -71,7 +71,7 @@ def test_receive_data_before_server_connected(tctx):
will still be forwarded.
"""
assert (
playbook(TCPLayer(tctx))
playbook(TCPLayer(tctx), hooks=False)
<< OpenConnection(tctx.server)
>> DataReceived(tctx.client, b"hello!")
>> reply(None, to=-2)

View File

@ -3,11 +3,15 @@ import ssl
import typing
import pytest
from OpenSSL import SSL
from mitmproxy.proxy2 import context, events, commands
from mitmproxy.proxy2 import commands, context, events
from mitmproxy.proxy2.layers import tls
from mitmproxy.utils import data
from test.mitmproxy.proxy2 import tutils
tlsdata = data.Data(__name__)
def test_is_tls_handshake_record():
assert tls.is_tls_handshake_record(bytes.fromhex("160300"))
@ -33,11 +37,11 @@ def test_record_contents():
def test_record_contents_err():
with pytest.raises(ValueError, msg="Expected TLS record"):
with pytest.raises(ValueError, match="Expected TLS record"):
next(tls.handshake_record_contents(b"GET /error"))
empty_record = bytes.fromhex("1603010000")
with pytest.raises(ValueError, msg="Record must not be empty"):
with pytest.raises(ValueError, match="Record must not be empty"):
next(tls.handshake_record_contents(empty_record))
@ -65,7 +69,8 @@ def test_get_client_hello():
class SSLTest:
"""Helper container for Python's builtin SSL object."""
def __init__(self, server_side=False, alpn=None):
def __init__(self, server_side: bool = False, alpn: typing.List[bytes] = None,
sni: typing.Optional[bytes] = b"example.com"):
self.inc = ssl.MemoryBIO()
self.out = ssl.MemoryBIO()
self.ctx = ssl.SSLContext()
@ -77,83 +82,78 @@ class SSLTest:
self.obj = self.ctx.wrap_bio(
self.inc,
self.out,
server_hostname=None if server_side else "example.com",
server_hostname=None if server_side else sni,
server_side=server_side,
)
def _test_tls_client_server(
tctx: context.Context,
alpn: typing.Optional[str]
) -> typing.Tuple[tutils.playbook[tls.ClientTLSLayer], SSLTest]:
layer = tls.ClientTLSLayer(tctx)
playbook = tutils.playbook(layer)
tctx.server.tls = True
tctx.server.address = ("example.com", 443)
tssl_client = SSLTest(alpn=alpn)
# Handshake
assert (
playbook
<< None
)
with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake()
client_hello = tssl_client.out.read()
assert (
playbook
>> events.DataReceived(tctx.client, client_hello[:42])
<< None
)
# Still waiting...
# Finish sending ClientHello
playbook >> events.DataReceived(tctx.client, client_hello[42:])
return playbook, tssl_client
def echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
def _test_echo(playbook: tutils.playbook, tssl: SSLTest, conn: context.Connection) -> None:
tssl.obj.write(b"Hello World")
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(conn, tssl.out.read())
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(tutils.EchoLayer)
<< commands.SendData(conn, data)
)
tssl.inc.write(data())
assert tssl.obj.read() == b"hello world"
class TlsEchoLayer(tutils.EchoLayer):
err: typing.Optional[str] = None
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.DataReceived) and event.data == b"establish-server-tls":
# noinspection PyTypeChecker
self.err = yield tls.EstablishServerTLS(self.context.server)
else:
yield from super()._handle_event(event)
def interact(playbook: tutils.playbook, conn: context.Connection, tssl: SSLTest):
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(conn, tssl.out.read())
<< commands.SendData(conn, data)
)
tssl.inc.write(data())
def reply_tls_start(*args, **kwargs) -> tutils.reply:
"""
Helper function to simplify the syntax for tls_start hooks.
"""
def make_conn(hook: commands.Hook) -> None:
tls_start = hook.data
assert isinstance(tls_start, tls.TlsStart)
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
if tls_start.conn == tls_start.context.client:
ssl_context.use_privatekey_file(
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.key")
)
ssl_context.use_certificate_chain_file(
tlsdata.path("../../net/data/verificationcerts/trusted-leaf.crt")
)
tls_start.ssl_conn = SSL.Connection(ssl_context)
return tutils.reply(*args, side_effect=make_conn, **kwargs)
class TestServerTLS:
def test_no_tls(self, tctx: context.Context):
"""Test TLS layer without TLS"""
layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer)
layer.child_layer = TlsEchoLayer(tctx)
# Handshake
assert (
playbook
tutils.playbook(layer)
>> events.DataReceived(tctx.client, b"Hello World")
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(tutils.EchoLayer)
<< commands.SendData(tctx.client, b"hello world")
)
def test_no_connection(self, tctx):
"""
The server TLS layer is initiated, but there is no active connection yet, so nothing
should be done.
"""
layer = tls.ServerTLSLayer(tctx)
playbook = tutils.playbook(layer)
tctx.server.tls = True
# We did not have a server connection before, so let's do nothing.
assert (
playbook
<< None
>> events.DataReceived(tctx.server, b"Foo")
<< commands.SendData(tctx.server, b"foo")
)
def test_simple(self, tctx):
@ -161,7 +161,6 @@ class TestServerTLS:
playbook = tutils.playbook(layer)
tctx.server.connected = True
tctx.server.address = ("example.com", 443)
tctx.server.tls = True
tssl = SSLTest(server_side=True)
@ -169,6 +168,11 @@ class TestServerTLS:
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.client, b"establish-server-tls")
<< commands.Hook("next_layer", tutils.Placeholder())
>> tutils.next_layer(TlsEchoLayer)
<< commands.Hook("tls_start", tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.server, data)
)
@ -176,13 +180,7 @@ class TestServerTLS:
tssl.inc.write(data())
with pytest.raises(ssl.SSLWantReadError):
tssl.obj.do_handshake()
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.server, tssl.out.read())
<< commands.SendData(tctx.server, data)
)
tssl.inc.write(data())
interact(playbook, tctx.server, tssl)
# finish server handshake
tssl.obj.do_handshake()
@ -193,134 +191,141 @@ class TestServerTLS:
)
assert tctx.server.tls_established
assert tctx.server.sni == b"example.com"
# Echo
echo(playbook, tssl, tctx.server)
_test_echo(playbook, tssl, tctx.server)
def _make_client_tls_layer(tctx: context.Context) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer]:
# This is a bit contrived as the client layer expects a server layer as parent.
# We also set child layers manually to avoid NextLayer noise.
server_layer = tls.ServerTLSLayer(tctx)
client_layer = tls.ClientTLSLayer(tctx)
server_layer.child_layer = client_layer
client_layer.child_layer = TlsEchoLayer(tctx)
playbook = tutils.playbook(server_layer)
return playbook, client_layer
def _test_tls_client_server(
tctx: context.Context,
sni: typing.Optional[bytes]
) -> typing.Tuple[tutils.playbook, tls.ClientTLSLayer, SSLTest]:
playbook, client_layer = _make_client_tls_layer(tctx)
tctx.server.tls = True
tctx.server.address = ("example.com", 443)
tssl_client = SSLTest(sni=sni)
# Send ClientHello
with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake()
return playbook, client_layer, tssl_client
class TestClientTLS:
def test_simple(self, tctx: context.Context):
def test_client_only(self, tctx: context.Context):
"""Test TLS with client only"""
layer = tls.ClientTLSLayer(tctx)
playbook = tutils.playbook(layer)
playbook, client_layer = _make_client_tls_layer(tctx)
tssl = SSLTest()
assert not tctx.client.tls_established
# Handshake
assert playbook
assert layer._handle_event == layer.state_wait_for_clienthello
def interact():
# Start Handshake, send ClientHello and ServerHello
with pytest.raises(ssl.SSLWantReadError):
tssl.obj.do_handshake()
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.client, tssl.out.read())
<< commands.Hook("tls_start", tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.client, data)
)
tssl.inc.write(data())
try:
tssl.obj.do_handshake()
except ssl.SSLWantReadError:
return False
else:
return True
# receive ClientHello, send ServerHello
with pytest.raises(ssl.SSLWantReadError):
tssl.obj.do_handshake()
assert not interact()
# Finish Handshake
assert interact()
tssl.obj.do_handshake()
interact(playbook, tctx.client, tssl)
assert layer._handle_event == layer.state_process
assert tssl.obj.getpeercert(True)
assert tctx.client.tls_established
# Echo
echo(playbook, tssl, tctx.client)
_test_echo(playbook, tssl, tctx.client)
assert (
playbook
>> events.DataReceived(tctx.server, b"Hello")
<< commands.SendData(tctx.server, b"hello")
>> events.DataReceived(tctx.server, b"Plaintext")
<< commands.SendData(tctx.server, b"plaintext")
)
def test_no_server_conn_required(self, tctx):
def test_server_not_required(self, tctx):
"""
Here we test the scenario where a server connection is _not_ required
to establish TLS with the client. After determining this when parsing the ClientHello,
we only establish a connection with the client. The server connection may ultimately
be established when OpenConnection is called.
"""
playbook, _ = _test_tls_client_server(tctx, None)
playbook, client_layer, tssl = _test_tls_client_server(tctx, sni=b"example.com")
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.client, tssl.out.read())
<< commands.Hook("tls_start", tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.client, data)
)
assert data()
assert playbook.layer._handle_event == playbook.layer.state_process
tssl.inc.write(data())
tssl.obj.do_handshake()
interact(playbook, tctx.client, tssl)
assert tctx.client.tls_established
def test_alpn(self, tctx):
def test_server_required(self, tctx):
"""
Here we test the scenario where a server connection is required (e.g. because of ALPN negotation)
Here we test the scenario where a server connection is required (because SNI is missing)
to establish TLS with the client.
"""
tssl_server = SSLTest(server_side=True, alpn=["foo", "bar"])
playbook, tssl_client = _test_tls_client_server(tctx, ["qux", "foo"])
tssl_server = SSLTest(server_side=True)
playbook, client_layer, tssl_client = _test_tls_client_server(tctx, sni=None)
# We should now get instructed to open a server connection.
assert (
playbook
<< commands.OpenConnection(tctx.server)
)
tctx.server.connected = True
data = tutils.Placeholder()
assert (
playbook
>> events.OpenConnectionReply(-1, None)
>> events.DataReceived(tctx.client, tssl_client.out.read())
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
<< commands.Hook("tls_start", tutils.Placeholder())
>> reply_tls_start()
<< commands.SendData(tctx.server, data)
)
assert playbook.layer._handle_event == playbook.layer.state_wait_for_server_tls
assert playbook.layer.child_layer.tls[tctx.server]
# Establish TLS with the server...
tssl_server.inc.write(data())
with pytest.raises(ssl.SSLWantReadError):
tssl_server.obj.do_handshake()
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.server, tssl_server.out.read())
<< commands.SendData(tctx.server, data)
<< commands.Hook("tls_start", tutils.Placeholder())
)
tssl_server.inc.write(data())
tssl_server.obj.do_handshake()
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.server, tssl_server.out.read())
<< commands.SendData(tctx.client, data)
)
assert playbook.layer._handle_event == playbook.layer.state_process
assert tctx.server.tls_established
# Server TLS is established, we can now reply to the client handshake...
tssl_client.inc.write(data())
with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake()
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.client, tssl_client.out.read())
>> reply_tls_start()
<< commands.SendData(tctx.client, data)
)
tssl_client.inc.write(data())
tssl_client.obj.do_handshake()
interact(playbook, tctx.client, tssl_client)
# Both handshakes completed!
assert tctx.client.tls_established
assert tctx.server.tls_established
assert tssl_client.obj.selected_alpn_protocol() == "foo"
assert tssl_server.obj.selected_alpn_protocol() == "foo"
_test_echo(playbook, tssl_server, tctx.server)
_test_echo(playbook, tssl_client, tctx.client)

View File

@ -5,13 +5,13 @@ from test.mitmproxy.proxy2 import tutils
class TestNextLayer:
def test_simple(self, tctx):
nl = layer.NextLayer(tctx)
playbook = tutils.playbook(nl)
playbook = tutils.playbook(nl, hooks=True)
assert (
playbook
>> events.DataReceived(tctx.client, b"foo")
<< commands.Hook("next_layer", nl)
>> events.HookReply(-1)
>> tutils.reply()
>> events.DataReceived(tctx.client, b"bar")
<< commands.Hook("next_layer", nl)
)
@ -21,7 +21,7 @@ class TestNextLayer:
nl.layer = tutils.EchoLayer(tctx)
assert (
playbook
>> events.HookReply(-1)
>> tutils.reply()
<< commands.SendData(tctx.client, b"foo")
<< commands.SendData(tctx.client, b"bar")
)
@ -45,7 +45,7 @@ class TestNextLayer:
assert (
playbook
>> events.HookReply(-2)
>> tutils.reply(to=-2)
<< commands.SendData(tctx.client, b"foo")
<< commands.SendData(tctx.client, b"bar")
)
@ -63,7 +63,7 @@ class TestNextLayer:
handle = nl.handle_event
assert (
playbook
>> events.HookReply(-1)
>> tutils.reply()
<< commands.SendData(tctx.client, b"foo")
)
sd, = handle(events.DataReceived(tctx.client, b"bar"))

View File

@ -38,7 +38,7 @@ class TLayer(Layer):
@pytest.fixture
def tplaybook(tctx):
return tutils.playbook(TLayer(tctx), [])
return tutils.playbook(TLayer(tctx), expected=[])
def test_simple(tplaybook):
@ -158,7 +158,7 @@ def test_command_reply(tplaybook):
tplaybook
>> TEvent()
<< TCommand()
>> TCommandReply(-1, 42)
>> tutils.reply(42)
)
assert tplaybook.actual[1] == tplaybook.actual[2].command

View File

@ -2,6 +2,7 @@ import collections.abc
import copy
import difflib
import itertools
import sys
import typing
from mitmproxy.proxy2 import commands, context
@ -101,7 +102,7 @@ class playbook:
def __init__(
self,
layer: Layer,
hooks: bool = False,
hooks: bool = True,
logs: bool = False,
expected: typing.Optional[TPlaybook] = None,
):
@ -196,13 +197,13 @@ class playbook:
class reply(events.Event):
args: typing.Tuple[typing.Any, ...]
to: typing.Union[commands.Command, int]
side_effect: typing.Callable[[commands.Command], typing.Any]
side_effect: typing.Callable[[typing.Any], typing.Any]
def __init__(
self,
*args,
to: typing.Union[commands.Command, int] = -1,
side_effect: typing.Callable[[commands.Command], typing.Any] = lambda cmd: None
side_effect: typing.Callable[[typing.Any], None] = lambda x: None
):
"""Utility method to reply to the latest hook in playbooks."""
self.args = args
@ -226,6 +227,7 @@ class reply(events.Event):
actual_str = "\n".join(_fmt_entry(x) for x in playbook.actual)
raise AssertionError(f"Expected command ({self.to}) did not occur:\n{actual_str}")
assert isinstance(self.to, commands.Command)
self.side_effect(self.to)
reply_cls = command_reply_subclasses[type(self.to)]
try:
@ -272,14 +274,16 @@ def Placeholder() -> typing.Any:
class EchoLayer(Layer):
"""Echo layer that sends all data back to the client in lowercase."""
def _handle_event(self, event: events.Event):
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.DataReceived):
yield commands.SendData(event.connection, event.data.lower())
def next_layer(
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]]
) -> events.HookReply:
layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]],
*args,
**kwargs
) -> reply:
"""
Helper function to simplify the syntax for next_layer events from this:
@ -294,21 +298,10 @@ def next_layer(
<< commands.Hook("next_layer", next_layer)
>> tutils.next_layer(next_layer, tutils.EchoLayer)
>> tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context)
"""
raise RuntimeError("Does tutils.reply(side_effect=lambda cmd: cmd.layer = tutils.EchoLayer(cmd.data.context) work?")
if isinstance(layer, type):
def make_layer(ctx: context.Context) -> Layer:
return layer(ctx)
else:
make_layer = layer
def set_layer(playbook: playbook) -> None:
last_command = playbook.actual[-1]
assert isinstance(last_command, commands.Hook)
assert isinstance(last_command.data, NextLayer)
last_command.data.layer = make_layer(last_command.data.context)
def set_layer(hook: commands.Hook) -> None:
assert isinstance(hook.data, NextLayer)
hook.data.layer = layer(hook.data.context)
reply = events.HookReply(-1)
reply._playbook_eval = set_layer
return reply
return reply(*args, side_effect=set_layer, **kwargs)