[sans-io] split tls layer into client and server layers

this drastically reduces the complexity of the TLS code and
makes it easier to implement the remaining bits.
This commit is contained in:
Maximilian Hils 2017-12-16 00:59:12 +01:00
parent 34274744a1
commit 8f3db90def
6 changed files with 317 additions and 275 deletions

View File

@ -1,4 +1,4 @@
from typing import Optional, List, Union
from typing import Optional, List, Union, Sequence
from mitmproxy.options import Options
@ -10,7 +10,9 @@ class Connection:
address: tuple
connected: bool = False
tls: bool = False
tls_established: bool = False
alpn: Optional[bytes] = None
alpn_offers: Sequence[bytes] = ()
def __repr__(self):
return f"{type(self).__name__}({repr(self.__dict__)})"

View File

@ -1,13 +1,13 @@
from . import modes
from .http import HTTPLayer
from .tcp import TCPLayer
from .tls import TLSLayer
from .tls import ClientTLSLayer, ServerTLSLayer
from .websocket import WebsocketLayer
__all__ = [
"modes",
"HTTPLayer",
"TCPLayer",
"TLSLayer",
"ClientTLSLayer", "ServerTLSLayer",
"WebsocketLayer"
]

View File

@ -6,9 +6,12 @@ from mitmproxy.proxy2.context import Context, Server
class ReverseProxy(layer.Layer):
def __init__(self, context: Context):
super().__init__(context)
server_addr = server_spec.parse_with_mode(context.options.mode)[1].address
self.context.server = Server(server_addr)
spec = server_spec.parse_with_mode(context.options.mode)[1]
self.context.server = Server(spec.address)
if spec.scheme != "http":
self.context.server.tls = True
if not context.options.keep_host_header:
self.context.server.sni = spec.address[0]
child_layer = layer.NextLayer(self.context)
self._handle_event = child_layer.handle_event

View File

@ -1,7 +1,6 @@
import os
import struct
from enum import Enum
from typing import MutableMapping, Optional, Iterator
from typing import MutableMapping, Optional, Iterator, Union, Generator, Any
from OpenSSL import SSL
@ -13,14 +12,6 @@ from mitmproxy.proxy2 import layer, commands, events
from mitmproxy.proxy2.utils import expect
class ConnectionState(Enum):
NO_TLS = 1
WAIT_FOR_CLIENTHELLO = 2
WAIT_FOR_SERVER_TLS = 3
NEGOTIATING = 5
ESTABLISHED = 6
def is_tls_handshake_record(d: bytes) -> bool:
"""
Returns:
@ -78,265 +69,34 @@ def get_client_hello(data: bytes) -> Optional[bytes]:
return None
class TLSLayer(layer.Layer):
def parse_client_hello(data: bytes) -> Optional[TlsClientHello]:
"""
The TLS layer manages both client- and server-side TLS connection state.
This unfortunately is quite complex as the client handshake may depend on the server
handshake and vice versa: We need the client's SNI and ALPN to connect upstream,
and we need the server's ALPN choice to complete our client TLS handshake.
On top, we may have configurations where TLS is only added on one end,
and we also may have OpenConnection events which change the server's TLS configuration.
Check if the supplied bytes contain a full ClientHello message,
and if so, parse it.
Returns:
- A ClientHello object on success
- None, if the TLS record is not complete
The following state machine shows the possible states for client and server connection:
Legend:
/: NO_TLS
WCH: WAIT_FOR_CLIENTHELLO
WST: WAIT_FOR_SERVER_TLS
N: NEGOTIATING
E: ESTABLISHED
+------------+ +---+ +------------+
|Client State|--------> | / | |Server State|
+------------+ no tls +---+ server tls, +------------+ server tls,
| client tls | | | no client tls
v client tls +-----------------+ | +--------------------+
| | |
+------------------------------+ | | no server tls |
| no server tls | v v v
| v OpenConn(TLS)
v server tls +---+ not needed +--------------------> +---+
+---> +---+ |WCH+-------------> | / | | N | <-+
+---+ | | N | +---+ +---+ <--------------------+ |
|WCH| | +> +---+ | OpenConn(No TLS) | |
+---+ | | | | ^ | |
| | | v | already connec- | handshake done v |
|ClientHello arrives | | +---+ | ted or server | |
| | | | E | | info needed | OpenConn(No TLS) +---+ |
+----------------------+ | +---+ | +--------------------+ E | |
| no server info needed | | +---+ |
| | | |
v server info needed | +------------------------------------------------+
|
+---+ |
|WST|-----------------------+
+---+ server tls established
(or errored)
Raises:
- A ValueError, if the passed ClientHello is invalid
"""
tls: MutableMapping[context.Connection, SSL.Connection]
state: MutableMapping[context.Connection, ConnectionState]
recv_buffer: MutableMapping[context.Connection, bytearray]
client_hello: Optional[TlsClientHello]
child_layer: layer.Layer
def __init__(self, context: context.Context):
super().__init__(context)
self.tls = {}
self.state = {}
self.recv_buffer = {
context.client: bytearray(),
context.server: bytearray()
}
self.client_hello = None
self.child_layer = layer.NextLayer(context)
@expect(events.Start)
def start(self, _) -> commands.TCommandGenerator:
client = self.context.client
server = self.context.server
if client.tls and server.tls:
self.state[client] = ConnectionState.WAIT_FOR_CLIENTHELLO
self.state[server] = ConnectionState.WAIT_FOR_CLIENTHELLO
elif client.tls:
self.state[server] = ConnectionState.NO_TLS
yield from self.start_client_tls()
elif server.tls and server.connected:
self.state[client] = ConnectionState.NO_TLS
yield from self.start_server_tls()
else:
self.state[client] = ConnectionState.NO_TLS
self.state[server] = ConnectionState.NO_TLS
yield from self.child_layer.handle_event(events.Start())
self._handle_event = self.process
_handle_event = start
def send(self, send_command: commands.SendData) -> commands.TCommandGenerator:
if self.state[send_command.connection] == ConnectionState.NO_TLS:
yield send_command
else:
yield commands.Log(f"Plain{send_command}")
self.tls[send_command.connection].sendall(send_command.data)
yield from self.tls_interact(send_command.connection)
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
for command in self.child_layer.handle_event(event):
if isinstance(command, commands.SendData):
yield from self.send(command)
elif isinstance(command, commands.OpenConnection):
raise NotImplementedError("Cannot open connection")
else:
yield command
def parse_client_hello(self):
# Check if ClientHello is complete
client_hello = get_client_hello(self.recv_buffer[self.context.client])
client_hello = get_client_hello(data)
if client_hello:
self.client_hello = TlsClientHello(client_hello[4:])
return True
return False
return TlsClientHello(client_hello[4:])
return None
def process(self, event: events.Event):
if isinstance(event, events.DataReceived):
state = self.state[event.connection]
if state is ConnectionState.WAIT_FOR_CLIENTHELLO:
yield from self.process_wait_for_clienthello(event)
elif state is ConnectionState.WAIT_FOR_SERVER_TLS:
self.recv_buffer[self.context.client].extend(event.data)
elif state is ConnectionState.NEGOTIATING:
yield from self.process_negotiate(event)
elif state is ConnectionState.NO_TLS:
yield from self.event_to_child(event)
elif state is ConnectionState.ESTABLISHED:
yield from self.process_relay(event)
else:
raise RuntimeError("Unexpected state")
else:
yield from self.event_to_child(event)
class _TLSLayer(layer.Layer):
send_buffer: MutableMapping[SSL.Connection, bytearray]
tls: MutableMapping[context.Connection, SSL.Connection]
child_layer: Optional[layer.Layer] = None
def process_wait_for_clienthello(self, event: events.DataReceived):
client = self.context.client
server = self.context.server
# We are not ready to process this yet.
self.recv_buffer[event.connection].extend(event.data)
if event.connection == client and self.parse_client_hello():
yield commands.Log(f"Client Hello: {self.client_hello}")
client_tls_requires_server_connection = (
self.context.server.tls and
self.context.options.upstream_cert and
(
self.context.options.add_upstream_certs_to_client_chain or
self.client_hello.alpn_protocols or
not self.client_hello.sni
)
)
# What do we do with the client connection now?
if client_tls_requires_server_connection:
self.state[client] = ConnectionState.WAIT_FOR_SERVER_TLS
else:
yield from self.start_client_tls()
# What do we do with the server connection now?
if client_tls_requires_server_connection and not self.context.server.connected:
yield commands.OpenConnection(self.context.server)
if not self.context.server.connected:
self.state[server] = ConnectionState.NO_TLS
else:
yield from self.start_server_tls()
def process_negotiate(self, event: events.DataReceived):
# bio_write errors for b"", so we need to check first if we actually received something.
if event.data:
self.tls[event.connection].bio_write(event.data)
try:
self.tls[event.connection].do_handshake()
except SSL.WantReadError:
yield from self.tls_interact(event.connection)
else:
self.state[event.connection] = ConnectionState.ESTABLISHED
event.connection.sni = self.tls[event.connection].get_servername()
event.connection.alpn = self.tls[event.connection].get_alpn_proto_negotiated()
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
yield from self.process(events.DataReceived(event.connection, b""))
if self.state[self.context.client] == ConnectionState.WAIT_FOR_SERVER_TLS:
assert event.connection == self.context.server
yield from self.start_client_tls()
def process_relay(self, event: events.DataReceived):
if event.data:
self.tls[event.connection].bio_write(event.data)
yield from self.tls_interact(event.connection)
plaintext = bytearray()
while True:
try:
plaintext.extend(self.tls[event.connection].recv(65535))
except (SSL.WantReadError, SSL.ZeroReturnError):
break
if plaintext:
evt = events.DataReceived(event.connection, bytes(plaintext))
yield commands.Log(f"Plain{evt}")
yield from self.event_to_child(evt)
def start_server_tls(self):
server = self.context.server
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
if self.client_hello:
alpn = [
x for x in self.client_hello.alpn_protocols
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
]
ssl_context.set_alpn_protos(alpn)
self.tls[server] = SSL.Connection(ssl_context)
if server.sni:
if server.sni is True:
if self.client_hello and self.client_hello.sni:
server.sni = self.client_hello.sni.encode("idna")
else:
server.sni = server.address[0].encode("idna")
self.tls[server].set_tlsext_host_name(server.sni)
self.tls[server].set_connect_state()
self.state[server] = ConnectionState.NEGOTIATING
yield from self.process(events.DataReceived(
server, bytes(self.recv_buffer[server])
))
self.recv_buffer[server] = bytearray()
def start_client_tls(self):
# FIXME
client = self.context.client
server = self.context.server
context = SSL.Context(SSL.SSLv23_METHOD)
cert, privkey, cert_chain = CertStore.from_store(
os.path.expanduser("~/.mitmproxy"), "mitmproxy"
).get_cert(b"example.com", (b"example.com",))
context.use_privatekey(privkey)
context.use_certificate(cert.x509)
context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS)
if self.state[server] == ConnectionState.ESTABLISHED:
alpn_for_client = self.tls[server].get_alpn_proto_negotiated()
def alpn_select_callback(conn_, options):
if alpn_for_client in options:
return alpn_for_client
context.set_alpn_select_callback(alpn_select_callback)
self.tls[client] = SSL.Connection(context)
self.tls[client].set_accept_state()
self.state[client] = ConnectionState.NEGOTIATING
yield from self.process(events.DataReceived(
client, bytes(self.recv_buffer[client])
))
self.recv_buffer[client] = bytearray()
def __init__(self, context):
super().__init__(context)
self.send_buffer = {}
self.tls = {}
def tls_interact(self, conn: context.Connection):
while True:
@ -347,3 +107,269 @@ class TLSLayer(layer.Layer):
return
else:
yield commands.SendData(conn, data)
def send(
self,
send_command: commands.SendData,
) -> commands.TCommandGenerator:
tls_conn = self.tls[send_command.connection]
if send_command.connection.tls_established:
tls_conn.sendall(send_command.data)
yield from self.tls_interact(send_command.connection)
else:
buf = self.send_buffer.setdefault(tls_conn, bytearray())
buf.extend(send_command.data)
def negotiate(self, event: events.DataReceived) -> Generator[commands.Command, Any, bool]:
"""
Make sure to trigger processing if done!
"""
# bio_write errors for b"", so we need to check first if we actually received something.
tls_conn = self.tls[event.connection]
if event.data:
tls_conn.bio_write(event.data)
try:
tls_conn.do_handshake()
except SSL.WantReadError:
yield from self.tls_interact(event.connection)
return False
else:
event.connection.tls_established = True
event.connection.alpn = tls_conn.get_alpn_proto_negotiated()
print(f"TLS established: {event.connection}")
# TODO: Set all other connection attributes here
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
yield from self.relay(events.DataReceived(event.connection, b""))
if tls_conn in self.send_buffer:
data_to_send = bytes(self.send_buffer.pop(tls_conn))
yield from self.send(commands.SendData(event.connection, data_to_send))
return True
def relay(self, event: events.DataReceived):
tls_conn = self.tls[event.connection]
if event.data:
tls_conn.bio_write(event.data)
yield from self.tls_interact(event.connection)
plaintext = bytearray()
while True:
try:
plaintext.extend(tls_conn.recv(65535))
except (SSL.WantReadError, SSL.ZeroReturnError):
break
if plaintext:
evt = events.DataReceived(event.connection, bytes(plaintext))
# yield commands.Log(f"Plain{evt}")
yield from self.event_to_child(evt)
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
for command in self.child_layer.handle_event(event):
if isinstance(command, commands.SendData) and command.connection in self.tls:
yield from self.send(command)
else:
yield command
class ServerTLSLayer(_TLSLayer):
"""
This layer manages TLS on potentially multiple server connections.
"""
def __init__(self, context: context.Context):
super().__init__(context)
self.child_layer = layer.NextLayer(context)
@expect(events.Start)
def start(self, event: events.Start) -> commands.TCommandGenerator:
yield from self.child_layer.handle_event(event)
server = self.context.server
if server.connected and server.tls:
yield from self._start_tls(server)
self._handle_event = self.process
_handle_event = start
def process(self, event: Union[events.DataReceived, events.ConnectionClosed]):
if isinstance(event, events.DataReceived) and event.connection in self.tls:
if not event.connection.tls_established:
yield from self.negotiate(event)
else:
yield from self.relay(event)
elif isinstance(event, events.OpenConnectionReply):
err = event.reply
conn = event.command.connection
if not err and conn.tls:
yield from self._start_tls(conn)
yield from self.event_to_child(event)
elif isinstance(event, events.ConnectionClosed):
yield from self.event_to_child(event)
self.send_buffer.pop(
self.tls.pop(event.connection, None),
None
)
else:
yield from self.event_to_child(event)
def _start_tls(self, server: context.Server):
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
if server.alpn_offers:
ssl_context.set_alpn_protos(server.alpn_offers)
self.tls[server] = SSL.Connection(ssl_context)
if server.sni:
if server.sni is True:
if self.context.client.sni:
server.sni = self.context.client.sni.encode("idna")
else:
server.sni = server.address[0].encode("idna")
self.tls[server].set_tlsext_host_name(server.sni)
self.tls[server].set_connect_state()
yield from self.process(events.DataReceived(server, b""))
class ClientTLSLayer(_TLSLayer):
"""
This layer establishes TLS on a single client connection.
Start
Wait for ClientHello
Do we need server TLS info
to establish TLS with client?
Wait for Server TLS
yes
no
Process messages
"""
recv_buffer: bytearray
def __init__(self, context: context.Context):
super().__init__(context)
self.recv_buffer = bytearray()
self.child_layer = ServerTLSLayer(self.context)
@expect(events.Start)
def state_start(self, _) -> commands.TCommandGenerator:
self.context.client.tls = True
self._handle_event = self.state_wait_for_clienthello
yield from ()
_handle_event = state_start
@expect(events.DataReceived, events.ConnectionClosed)
def state_wait_for_clienthello(self, event: events.Event):
client = self.context.client
server = self.context.server
if isinstance(event, events.DataReceived) and event.connection == client:
self.recv_buffer.extend(event.data)
try:
client_hello = parse_client_hello(self.recv_buffer)
except ValueError as e:
raise NotImplementedError() from e # TODO
if client_hello:
yield commands.Log(f"Client Hello: {client_hello}")
client.sni = client_hello.sni
client.alpn_offers = client_hello.alpn_protocols
client_tls_requires_server_connection = (
self.context.server.tls and
self.context.options.upstream_cert and
(
self.context.options.add_upstream_certs_to_client_chain or
client.alpn_offers or
not client.sni
)
)
# What do we do with the client connection now?
if client_tls_requires_server_connection and not server.tls_established:
yield from self.start_server_tls()
self._handle_event = self.state_wait_for_server_tls
else:
yield from self.start_negotiate()
self._handle_event = self.state_process
else:
raise NotImplementedError(event) # TODO
def state_wait_for_server_tls(self, event: events.Event):
yield from self.event_to_child(event)
# TODO: Handle case where TLS establishment fails.
# We still need a good way to signal this - one possibility would be by closing
# the connection?
if self.context.server.tls_established:
yield from self.start_negotiate()
self._handle_event = self.state_process
def state_process(self, event: events.Event):
if isinstance(event, events.DataReceived) and event.connection == self.context.client:
if not event.connection.tls_established:
yield from self.negotiate(event)
else:
yield from self.relay(event)
else:
yield from self.event_to_child(event)
def start_server_tls(self):
"""
We often need information from the upstream connection to establish TLS with the client.
For example, we need to check if the client does ALPN or not.
"""
if not self.context.server.connected:
err = yield commands.OpenConnection(self.context.server)
if err:
yield commands.Log(
"Cannot establish server connection, which is required to establish TLS with the client."
)
self.context.server.alpn_offers = [
x for x in self.context.client.alpn_offers
if not (x.startswith(b"h2-") or x.startswith(b"spdy"))
]
yield from self.child_layer.handle_event(events.Start())
def start_negotiate(self):
if not self.child_layer:
yield from self.child_layer.handle_event(events.Start())
# FIXME: Do this properly
client = self.context.client
server = self.context.server
context = SSL.Context(SSL.SSLv23_METHOD)
cert, privkey, cert_chain = CertStore.from_store(
os.path.expanduser("~/.mitmproxy"), "mitmproxy"
).get_cert(b"example.com", (b"example.com",))
context.use_privatekey(privkey)
context.use_certificate(cert.x509)
context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS)
if server.alpn:
def alpn_select_callback(conn_, options):
if server.alpn in options:
return server.alpn
context.set_alpn_select_callback(alpn_select_callback)
self.tls[self.context.client] = SSL.Connection(context)
self.tls[self.context.client].set_accept_state()
yield from self.state_process(events.DataReceived(
client, bytes(self.recv_buffer)
))
self.recv_buffer = bytearray()

View File

@ -59,7 +59,12 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
# self._debug("transports closed!")
async def close_connection(self, connection):
io = self.transports.pop(connection, None)
try:
io = self.transports.pop(connection)
except KeyError:
self.log(f"already closed: {connection}", "warn")
return
else:
self.log(f"closing {connection}", "debug")
try:
await io.w.drain()
@ -109,10 +114,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
print(message)
def server_event(self, event: events.Event) -> None:
self.log(f">> {event}", "debug")
layer_commands = self.layer.handle_event(event)
for command in layer_commands:
self.log(f"<< {command}", "debug")
if isinstance(command, commands.OpenConnection):
asyncio.ensure_future(
self.open_connection(command)
@ -155,13 +158,20 @@ if __name__ == "__main__":
loop = asyncio.get_event_loop()
opts = moptions.Options()
# opts.mode = "reverse:example.com"
opts.mode = "reverse:example.com"
# test client-tls-first scenario
# opts.upstream_cert = False
layers.ClientTLSLayer.debug = ""
layers.ServerTLSLayer.debug = " "
layers.TCPLayer.debug = " "
async def handle(reader, writer):
layer_stack = [
# layers.TLSLayer,
lambda c: layers.HTTPLayer(c, HTTPMode.regular),
layers.ClientTLSLayer,
#layers.ServerTLSLayer,
layers.TCPLayer,
# lambda c: layers.HTTPLayer(c, HTTPMode.transparent),
]
def next_layer(nl: layer.NextLayer):

View File

@ -57,7 +57,8 @@ def expect(*event_types):
yield from f(self, event)
else:
raise TypeError(
"Invalid event type: Expected {}, got {}".format(event_types, event))
"Invalid event type at {}: Expected {}, got {}.".format(f, event_types, event)
)
return wrapper