[sans-io] add ALPN support to TLS layer, tests++

This commit is contained in:
Maximilian Hils 2017-08-16 14:43:55 +02:00
parent 3f7b850268
commit b5a3343d03
2 changed files with 158 additions and 27 deletions

View File

@ -1,11 +1,10 @@
import os import os
import struct import struct
from enum import Enum from enum import Enum
from typing import MutableMapping, Generator, Optional, Iterable, Iterator from typing import MutableMapping, Optional, Iterator
from OpenSSL import SSL from OpenSSL import SSL
from mitmproxy import exceptions
from mitmproxy.certs import CertStore from mitmproxy.certs import CertStore
from mitmproxy.proxy.protocol import TlsClientHello from mitmproxy.proxy.protocol import TlsClientHello
from mitmproxy.proxy.protocol import tls from mitmproxy.proxy.protocol import tls
@ -153,8 +152,8 @@ class TLSLayer(layer.Layer):
self.state[client] = ConnectionState.WAIT_FOR_CLIENTHELLO self.state[client] = ConnectionState.WAIT_FOR_CLIENTHELLO
self.state[server] = ConnectionState.WAIT_FOR_CLIENTHELLO self.state[server] = ConnectionState.WAIT_FOR_CLIENTHELLO
elif client.tls: elif client.tls:
yield from self.start_client_tls()
self.state[server] = ConnectionState.NO_TLS self.state[server] = ConnectionState.NO_TLS
yield from self.start_client_tls()
elif server.tls and server.connected: elif server.tls and server.connected:
self.state[client] = ConnectionState.NO_TLS self.state[client] = ConnectionState.NO_TLS
yield from self.start_server_tls() yield from self.start_server_tls()
@ -186,17 +185,11 @@ class TLSLayer(layer.Layer):
def parse_client_hello(self): def parse_client_hello(self):
# Check if ClientHello is complete # Check if ClientHello is complete
try: client_hello = get_client_hello(self.recv_buffer[self.context.client])
client_hello = get_client_hello(self.recv_buffer[self.context.client])[4:] if client_hello:
self.client_hello = TlsClientHello(client_hello) self.client_hello = TlsClientHello(client_hello[4:])
except exceptions.TlsProtocolException:
return False
except EOFError as e:
raise exceptions.TlsProtocolException(
f'Cannot parse Client Hello: {e}, Raw Client Hello: {client_hello}'
)
else:
return True return True
return False
def process(self, event: events.Event): def process(self, event: events.Event):
if isinstance(event, events.DataReceived): if isinstance(event, events.DataReceived):
@ -235,21 +228,20 @@ class TLSLayer(layer.Layer):
not self.client_hello.sni not self.client_hello.sni
) )
) )
# What do we do with the client connection now?
if client_tls_requires_server_connection and not self.context.server.connected:
yield commands.OpenConnection(self.context.server)
if not self.context.server.connected:
# We are only in the WAIT_FOR_CLIENTHELLO branch if we have two TLS conns.
assert self.context.server.tls
self.state[server] = ConnectionState.NO_TLS
else:
yield from self.start_server_tls()
if client_tls_requires_server_connection: if client_tls_requires_server_connection:
self.state[client] = ConnectionState.WAIT_FOR_SERVER_TLS self.state[client] = ConnectionState.WAIT_FOR_SERVER_TLS
else: else:
yield from self.start_client_tls() yield from self.start_client_tls()
# What do we do with the server connection now?
if client_tls_requires_server_connection and not self.context.server.connected:
yield commands.OpenConnection(self.context.server)
if not self.context.server.connected:
self.state[server] = ConnectionState.NO_TLS
else:
yield from self.start_server_tls()
def process_negotiate(self, event: events.DataReceived): def process_negotiate(self, event: events.DataReceived):
# bio_write errors for b"", so we need to check first if we actually received something. # bio_write errors for b"", so we need to check first if we actually received something.
if event.data: if event.data:
@ -291,16 +283,23 @@ class TLSLayer(layer.Layer):
server = self.context.server server = self.context.server
ssl_context = SSL.Context(SSL.SSLv23_METHOD) ssl_context = SSL.Context(SSL.SSLv23_METHOD)
if self.client_hello:
alpn = [
x for x in self.client_hello.alpn_protocols
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
]
ssl_context.set_alpn_protos(alpn)
self.tls[server] = SSL.Connection(ssl_context) self.tls[server] = SSL.Connection(ssl_context)
if server.sni: if server.sni:
if server.sni is True: if server.sni is True:
if self.client_hello: if self.client_hello and self.client_hello.sni:
server.sni = self.client_hello.sni.encode("idna") server.sni = self.client_hello.sni.encode("idna")
else: else:
server.sni = server.address[0].encode("idna") server.sni = server.address[0].encode("idna")
self.tls[server].set_tlsext_host_name(server.sni) self.tls[server].set_tlsext_host_name(server.sni)
# FIXME: Handle ALPN
self.tls[server].set_connect_state() self.tls[server].set_connect_state()
self.state[server] = ConnectionState.NEGOTIATING self.state[server] = ConnectionState.NEGOTIATING
@ -312,6 +311,7 @@ class TLSLayer(layer.Layer):
def start_client_tls(self): def start_client_tls(self):
# FIXME # FIXME
client = self.context.client client = self.context.client
server = self.context.server
context = SSL.Context(SSL.SSLv23_METHOD) context = SSL.Context(SSL.SSLv23_METHOD)
cert, privkey, cert_chain = CertStore.from_store( cert, privkey, cert_chain = CertStore.from_store(
os.path.expanduser("~/.mitmproxy"), "mitmproxy" os.path.expanduser("~/.mitmproxy"), "mitmproxy"
@ -319,6 +319,16 @@ class TLSLayer(layer.Layer):
context.use_privatekey(privkey) context.use_privatekey(privkey)
context.use_certificate(cert.x509) context.use_certificate(cert.x509)
context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS) context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS)
if self.state[server] == ConnectionState.ESTABLISHED:
alpn_for_client = self.tls[server].get_alpn_proto_negotiated()
def alpn_select_callback(conn_, options):
if alpn_for_client in options:
return alpn_for_client
context.set_alpn_select_callback(alpn_select_callback)
self.tls[client] = SSL.Connection(context) self.tls[client] = SSL.Connection(context)
self.tls[client].set_accept_state() self.tls[client].set_accept_state()

View File

@ -64,14 +64,21 @@ def test_get_client_hello():
class SSLTest: class SSLTest:
"""Helper container for Python's builtin SSL object.""" """Helper container for Python's builtin SSL object."""
def __init__(self, server_side=False): def __init__(self, server_side=False, alpn=None):
self.inc = ssl.MemoryBIO() self.inc = ssl.MemoryBIO()
self.out = ssl.MemoryBIO() self.out = ssl.MemoryBIO()
self.ctx = ssl.SSLContext() self.ctx = ssl.SSLContext()
if alpn:
self.ctx.set_alpn_protocols(alpn)
if server_side: if server_side:
# FIXME: Replace hardcoded location # FIXME: Replace hardcoded location
self.ctx.load_cert_chain(os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem")) self.ctx.load_cert_chain(os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem"))
self.obj = self.ctx.wrap_bio(self.inc, self.out, server_side=server_side) self.obj = self.ctx.wrap_bio(
self.inc,
self.out,
server_hostname=None if server_side else "example.com",
server_side=server_side,
)
def test_no_tls(tctx: context.Context): def test_no_tls(tctx: context.Context):
@ -211,3 +218,117 @@ def test_server_tls(tctx):
# Echo # Echo
echo(playbook, tssl, tctx.server) echo(playbook, tssl, tctx.server)
def _test_tls_client_server(tctx, alpn):
layer = tls.TLSLayer(tctx)
playbook = tutils.playbook(layer)
tctx.client.tls = True
tctx.server.tls = True
tssl_client = SSLTest(alpn=alpn)
# Handshake
assert playbook
assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake()
client_hello = tssl_client.out.read()
assert (
playbook
>> events.DataReceived(tctx.client, client_hello[:42])
<< None
)
# Still waiting...
assert layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
assert layer.state[tctx.server] == tls.ConnectionState.WAIT_FOR_CLIENTHELLO
# Finish sending ClientHello
playbook >> events.DataReceived(tctx.client, client_hello[42:])
return playbook, tssl_client
def test_tls_client_server_no_server_conn(tctx):
"""
Here we test the scenario where a server connection is _not_ required
to establish TLS with the client. After determining this when parsing the ClientHello,
we only establish a connection with the client. The server connection may ultimately
be established when OpenConnection is called.
"""
playbook, _ = _test_tls_client_server(tctx, None)
data = tutils.Placeholder()
assert (
playbook
<< commands.SendData(tctx.client, data)
)
assert data()
assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
assert playbook.layer.state[tctx.server] == tls.ConnectionState.NO_TLS
def test_tls_client_server_alpn(tctx):
"""
Here we test the scenario where a server connection is required (e.g. because of ALPN negotation)
to establish TLS with the client.
"""
tssl_server = SSLTest(server_side=True, alpn=["foo", "bar"])
playbook, tssl_client = _test_tls_client_server(tctx, ["qux", "foo"])
# We should now get instructed to open a server connection.
assert (
playbook
<< commands.OpenConnection(tctx.server)
)
tctx.server.connected = True
data = tutils.Placeholder()
assert (
playbook
>> events.OpenConnectionReply(-1, None)
<< commands.SendData(tctx.server, data)
)
assert playbook.layer.state[tctx.client] == tls.ConnectionState.WAIT_FOR_SERVER_TLS
assert playbook.layer.state[tctx.server] == tls.ConnectionState.NEGOTIATING
# Establish TLS with the server...
tssl_server.inc.write(data())
with pytest.raises(ssl.SSLWantReadError):
tssl_server.obj.do_handshake()
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.server, tssl_server.out.read())
<< commands.SendData(tctx.server, data)
)
tssl_server.inc.write(data())
tssl_server.obj.do_handshake()
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.server, tssl_server.out.read())
<< commands.SendData(tctx.client, data)
)
assert playbook.layer.state[tctx.client] == tls.ConnectionState.NEGOTIATING
assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
# Server TLS is established, we can now reply to the client handshake...
tssl_client.inc.write(data())
with pytest.raises(ssl.SSLWantReadError):
tssl_client.obj.do_handshake()
data = tutils.Placeholder()
assert (
playbook
>> events.DataReceived(tctx.client, tssl_client.out.read())
<< commands.SendData(tctx.client, data)
)
tssl_client.inc.write(data())
tssl_client.obj.do_handshake()
# Both handshakes completed!
assert playbook.layer.state[tctx.client] == tls.ConnectionState.ESTABLISHED
assert playbook.layer.state[tctx.server] == tls.ConnectionState.ESTABLISHED
assert tssl_client.obj.selected_alpn_protocol() == "foo"
assert tssl_server.obj.selected_alpn_protocol() == "foo"