From 8f3db90def7bb0d693ac3ee55613773328763f82 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 16 Dec 2017 00:59:12 +0100 Subject: [PATCH] [sans-io] split tls layer into client and server layers this drastically reduces the complexity of the TLS code and makes it easier to implement the remaining bits. --- mitmproxy/proxy2/context.py | 4 +- mitmproxy/proxy2/layers/__init__.py | 4 +- mitmproxy/proxy2/layers/modes.py | 9 +- mitmproxy/proxy2/layers/tls.py | 548 +++++++++++++++------------- mitmproxy/proxy2/server.py | 24 +- mitmproxy/proxy2/utils.py | 3 +- 6 files changed, 317 insertions(+), 275 deletions(-) diff --git a/mitmproxy/proxy2/context.py b/mitmproxy/proxy2/context.py index c49c6d557..93ef97bc0 100644 --- a/mitmproxy/proxy2/context.py +++ b/mitmproxy/proxy2/context.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Union +from typing import Optional, List, Union, Sequence from mitmproxy.options import Options @@ -10,7 +10,9 @@ class Connection: address: tuple connected: bool = False tls: bool = False + tls_established: bool = False alpn: Optional[bytes] = None + alpn_offers: Sequence[bytes] = () def __repr__(self): return f"{type(self).__name__}({repr(self.__dict__)})" diff --git a/mitmproxy/proxy2/layers/__init__.py b/mitmproxy/proxy2/layers/__init__.py index 4da23f6b4..ff77b606c 100644 --- a/mitmproxy/proxy2/layers/__init__.py +++ b/mitmproxy/proxy2/layers/__init__.py @@ -1,13 +1,13 @@ from . import modes from .http import HTTPLayer from .tcp import TCPLayer -from .tls import TLSLayer +from .tls import ClientTLSLayer, ServerTLSLayer from .websocket import WebsocketLayer __all__ = [ "modes", "HTTPLayer", "TCPLayer", - "TLSLayer", + "ClientTLSLayer", "ServerTLSLayer", "WebsocketLayer" ] diff --git a/mitmproxy/proxy2/layers/modes.py b/mitmproxy/proxy2/layers/modes.py index 21ae205af..a2795aeed 100644 --- a/mitmproxy/proxy2/layers/modes.py +++ b/mitmproxy/proxy2/layers/modes.py @@ -6,9 +6,12 @@ from mitmproxy.proxy2.context import Context, Server class ReverseProxy(layer.Layer): def __init__(self, context: Context): super().__init__(context) - server_addr = server_spec.parse_with_mode(context.options.mode)[1].address - self.context.server = Server(server_addr) - + spec = server_spec.parse_with_mode(context.options.mode)[1] + self.context.server = Server(spec.address) + if spec.scheme != "http": + self.context.server.tls = True + if not context.options.keep_host_header: + self.context.server.sni = spec.address[0] child_layer = layer.NextLayer(self.context) self._handle_event = child_layer.handle_event diff --git a/mitmproxy/proxy2/layers/tls.py b/mitmproxy/proxy2/layers/tls.py index 925a864ea..118b1da2b 100644 --- a/mitmproxy/proxy2/layers/tls.py +++ b/mitmproxy/proxy2/layers/tls.py @@ -1,7 +1,6 @@ import os import struct -from enum import Enum -from typing import MutableMapping, Optional, Iterator +from typing import MutableMapping, Optional, Iterator, Union, Generator, Any from OpenSSL import SSL @@ -13,14 +12,6 @@ from mitmproxy.proxy2 import layer, commands, events from mitmproxy.proxy2.utils import expect -class ConnectionState(Enum): - NO_TLS = 1 - WAIT_FOR_CLIENTHELLO = 2 - WAIT_FOR_SERVER_TLS = 3 - NEGOTIATING = 5 - ESTABLISHED = 6 - - def is_tls_handshake_record(d: bytes) -> bool: """ Returns: @@ -78,265 +69,34 @@ def get_client_hello(data: bytes) -> Optional[bytes]: return None -class TLSLayer(layer.Layer): +def parse_client_hello(data: bytes) -> Optional[TlsClientHello]: """ - The TLS layer manages both client- and server-side TLS connection state. - This unfortunately is quite complex as the client handshake may depend on the server - handshake and vice versa: We need the client's SNI and ALPN to connect upstream, - and we need the server's ALPN choice to complete our client TLS handshake. - On top, we may have configurations where TLS is only added on one end, - and we also may have OpenConnection events which change the server's TLS configuration. + Check if the supplied bytes contain a full ClientHello message, + and if so, parse it. + Returns: + - A ClientHello object on success + - None, if the TLS record is not complete - The following state machine shows the possible states for client and server connection: - - Legend: - /: NO_TLS - WCH: WAIT_FOR_CLIENTHELLO - WST: WAIT_FOR_SERVER_TLS - N: NEGOTIATING - E: ESTABLISHED - - +------------+ +---+ +------------+ - |Client State|--------> | / | |Server State| - +------------+ no tls +---+ server tls, +------------+ server tls, - | client tls | | | no client tls - v client tls +-----------------+ | +--------------------+ - | | | - +------------------------------+ | | no server tls | - | no server tls | v v v - | v OpenConn(TLS) - v server tls +---+ not needed +--------------------> +---+ - +---> +---+ |WCH+-------------> | / | | N | <-+ - +---+ | | N | +---+ +---+ <--------------------+ | - |WCH| | +> +---+ | OpenConn(No TLS) | | - +---+ | | | | ^ | | - | | | v | already connec- | handshake done v | - |ClientHello arrives | | +---+ | ted or server | | - | | | | E | | info needed | OpenConn(No TLS) +---+ | - +----------------------+ | +---+ | +--------------------+ E | | - | no server info needed | | +---+ | - | | | | - v server info needed | +------------------------------------------------+ - | - +---+ | - |WST|-----------------------+ - +---+ server tls established - (or errored) + Raises: + - A ValueError, if the passed ClientHello is invalid """ + # Check if ClientHello is complete + client_hello = get_client_hello(data) + if client_hello: + return TlsClientHello(client_hello[4:]) + return None + + +class _TLSLayer(layer.Layer): + send_buffer: MutableMapping[SSL.Connection, bytearray] tls: MutableMapping[context.Connection, SSL.Connection] - state: MutableMapping[context.Connection, ConnectionState] - recv_buffer: MutableMapping[context.Connection, bytearray] - client_hello: Optional[TlsClientHello] + child_layer: Optional[layer.Layer] = None - child_layer: layer.Layer - - def __init__(self, context: context.Context): + def __init__(self, context): super().__init__(context) + self.send_buffer = {} self.tls = {} - self.state = {} - self.recv_buffer = { - context.client: bytearray(), - context.server: bytearray() - } - self.client_hello = None - - self.child_layer = layer.NextLayer(context) - - @expect(events.Start) - def start(self, _) -> commands.TCommandGenerator: - client = self.context.client - server = self.context.server - - if client.tls and server.tls: - self.state[client] = ConnectionState.WAIT_FOR_CLIENTHELLO - self.state[server] = ConnectionState.WAIT_FOR_CLIENTHELLO - elif client.tls: - self.state[server] = ConnectionState.NO_TLS - yield from self.start_client_tls() - elif server.tls and server.connected: - self.state[client] = ConnectionState.NO_TLS - yield from self.start_server_tls() - else: - self.state[client] = ConnectionState.NO_TLS - self.state[server] = ConnectionState.NO_TLS - - yield from self.child_layer.handle_event(events.Start()) - self._handle_event = self.process - - _handle_event = start - - def send(self, send_command: commands.SendData) -> commands.TCommandGenerator: - if self.state[send_command.connection] == ConnectionState.NO_TLS: - yield send_command - else: - yield commands.Log(f"Plain{send_command}") - self.tls[send_command.connection].sendall(send_command.data) - yield from self.tls_interact(send_command.connection) - - def event_to_child(self, event: events.Event) -> commands.TCommandGenerator: - for command in self.child_layer.handle_event(event): - if isinstance(command, commands.SendData): - yield from self.send(command) - elif isinstance(command, commands.OpenConnection): - raise NotImplementedError("Cannot open connection") - else: - yield command - - def parse_client_hello(self): - # Check if ClientHello is complete - client_hello = get_client_hello(self.recv_buffer[self.context.client]) - if client_hello: - self.client_hello = TlsClientHello(client_hello[4:]) - return True - return False - - def process(self, event: events.Event): - if isinstance(event, events.DataReceived): - state = self.state[event.connection] - - if state is ConnectionState.WAIT_FOR_CLIENTHELLO: - yield from self.process_wait_for_clienthello(event) - elif state is ConnectionState.WAIT_FOR_SERVER_TLS: - self.recv_buffer[self.context.client].extend(event.data) - elif state is ConnectionState.NEGOTIATING: - yield from self.process_negotiate(event) - elif state is ConnectionState.NO_TLS: - yield from self.event_to_child(event) - elif state is ConnectionState.ESTABLISHED: - yield from self.process_relay(event) - else: - raise RuntimeError("Unexpected state") - else: - yield from self.event_to_child(event) - - def process_wait_for_clienthello(self, event: events.DataReceived): - client = self.context.client - server = self.context.server - # We are not ready to process this yet. - self.recv_buffer[event.connection].extend(event.data) - - if event.connection == client and self.parse_client_hello(): - yield commands.Log(f"Client Hello: {self.client_hello}") - - client_tls_requires_server_connection = ( - self.context.server.tls and - self.context.options.upstream_cert and - ( - self.context.options.add_upstream_certs_to_client_chain or - self.client_hello.alpn_protocols or - not self.client_hello.sni - ) - ) - # What do we do with the client connection now? - if client_tls_requires_server_connection: - self.state[client] = ConnectionState.WAIT_FOR_SERVER_TLS - else: - yield from self.start_client_tls() - - # What do we do with the server connection now? - if client_tls_requires_server_connection and not self.context.server.connected: - yield commands.OpenConnection(self.context.server) - if not self.context.server.connected: - self.state[server] = ConnectionState.NO_TLS - else: - yield from self.start_server_tls() - - def process_negotiate(self, event: events.DataReceived): - # bio_write errors for b"", so we need to check first if we actually received something. - if event.data: - self.tls[event.connection].bio_write(event.data) - try: - self.tls[event.connection].do_handshake() - except SSL.WantReadError: - yield from self.tls_interact(event.connection) - else: - self.state[event.connection] = ConnectionState.ESTABLISHED - event.connection.sni = self.tls[event.connection].get_servername() - event.connection.alpn = self.tls[event.connection].get_alpn_proto_negotiated() - - # there might already be data in the OpenSSL BIO, so we need to trigger its processing. - yield from self.process(events.DataReceived(event.connection, b"")) - - if self.state[self.context.client] == ConnectionState.WAIT_FOR_SERVER_TLS: - assert event.connection == self.context.server - yield from self.start_client_tls() - - def process_relay(self, event: events.DataReceived): - if event.data: - self.tls[event.connection].bio_write(event.data) - yield from self.tls_interact(event.connection) - - plaintext = bytearray() - while True: - try: - plaintext.extend(self.tls[event.connection].recv(65535)) - except (SSL.WantReadError, SSL.ZeroReturnError): - break - - if plaintext: - evt = events.DataReceived(event.connection, bytes(plaintext)) - yield commands.Log(f"Plain{evt}") - yield from self.event_to_child(evt) - - def start_server_tls(self): - server = self.context.server - - ssl_context = SSL.Context(SSL.SSLv23_METHOD) - - if self.client_hello: - alpn = [ - x for x in self.client_hello.alpn_protocols - if not (x.startswith(b"h2-") or x.startswith(b"spdy")) - ] - ssl_context.set_alpn_protos(alpn) - - self.tls[server] = SSL.Connection(ssl_context) - - if server.sni: - if server.sni is True: - if self.client_hello and self.client_hello.sni: - server.sni = self.client_hello.sni.encode("idna") - else: - server.sni = server.address[0].encode("idna") - self.tls[server].set_tlsext_host_name(server.sni) - self.tls[server].set_connect_state() - - self.state[server] = ConnectionState.NEGOTIATING - yield from self.process(events.DataReceived( - server, bytes(self.recv_buffer[server]) - )) - self.recv_buffer[server] = bytearray() - - def start_client_tls(self): - # FIXME - client = self.context.client - server = self.context.server - context = SSL.Context(SSL.SSLv23_METHOD) - cert, privkey, cert_chain = CertStore.from_store( - os.path.expanduser("~/.mitmproxy"), "mitmproxy" - ).get_cert(b"example.com", (b"example.com",)) - context.use_privatekey(privkey) - context.use_certificate(cert.x509) - context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS) - - if self.state[server] == ConnectionState.ESTABLISHED: - alpn_for_client = self.tls[server].get_alpn_proto_negotiated() - - def alpn_select_callback(conn_, options): - if alpn_for_client in options: - return alpn_for_client - - context.set_alpn_select_callback(alpn_select_callback) - - self.tls[client] = SSL.Connection(context) - self.tls[client].set_accept_state() - - self.state[client] = ConnectionState.NEGOTIATING - yield from self.process(events.DataReceived( - client, bytes(self.recv_buffer[client]) - )) - self.recv_buffer[client] = bytearray() def tls_interact(self, conn: context.Connection): while True: @@ -347,3 +107,269 @@ class TLSLayer(layer.Layer): return else: yield commands.SendData(conn, data) + + def send( + self, + send_command: commands.SendData, + ) -> commands.TCommandGenerator: + tls_conn = self.tls[send_command.connection] + if send_command.connection.tls_established: + tls_conn.sendall(send_command.data) + yield from self.tls_interact(send_command.connection) + else: + buf = self.send_buffer.setdefault(tls_conn, bytearray()) + buf.extend(send_command.data) + + def negotiate(self, event: events.DataReceived) -> Generator[commands.Command, Any, bool]: + """ + Make sure to trigger processing if done! + """ + # bio_write errors for b"", so we need to check first if we actually received something. + tls_conn = self.tls[event.connection] + if event.data: + tls_conn.bio_write(event.data) + try: + tls_conn.do_handshake() + except SSL.WantReadError: + yield from self.tls_interact(event.connection) + return False + else: + event.connection.tls_established = True + event.connection.alpn = tls_conn.get_alpn_proto_negotiated() + print(f"TLS established: {event.connection}") + # TODO: Set all other connection attributes here + # there might already be data in the OpenSSL BIO, so we need to trigger its processing. + yield from self.relay(events.DataReceived(event.connection, b"")) + if tls_conn in self.send_buffer: + data_to_send = bytes(self.send_buffer.pop(tls_conn)) + yield from self.send(commands.SendData(event.connection, data_to_send)) + return True + + def relay(self, event: events.DataReceived): + tls_conn = self.tls[event.connection] + if event.data: + tls_conn.bio_write(event.data) + yield from self.tls_interact(event.connection) + + plaintext = bytearray() + while True: + try: + plaintext.extend(tls_conn.recv(65535)) + except (SSL.WantReadError, SSL.ZeroReturnError): + break + + if plaintext: + evt = events.DataReceived(event.connection, bytes(plaintext)) + # yield commands.Log(f"Plain{evt}") + yield from self.event_to_child(evt) + + def event_to_child(self, event: events.Event) -> commands.TCommandGenerator: + for command in self.child_layer.handle_event(event): + if isinstance(command, commands.SendData) and command.connection in self.tls: + yield from self.send(command) + else: + yield command + + +class ServerTLSLayer(_TLSLayer): + """ + This layer manages TLS on potentially multiple server connections. + """ + + def __init__(self, context: context.Context): + super().__init__(context) + self.child_layer = layer.NextLayer(context) + + @expect(events.Start) + def start(self, event: events.Start) -> commands.TCommandGenerator: + yield from self.child_layer.handle_event(event) + + server = self.context.server + if server.connected and server.tls: + yield from self._start_tls(server) + self._handle_event = self.process + + _handle_event = start + + def process(self, event: Union[events.DataReceived, events.ConnectionClosed]): + if isinstance(event, events.DataReceived) and event.connection in self.tls: + if not event.connection.tls_established: + yield from self.negotiate(event) + else: + yield from self.relay(event) + elif isinstance(event, events.OpenConnectionReply): + err = event.reply + conn = event.command.connection + if not err and conn.tls: + yield from self._start_tls(conn) + yield from self.event_to_child(event) + elif isinstance(event, events.ConnectionClosed): + yield from self.event_to_child(event) + self.send_buffer.pop( + self.tls.pop(event.connection, None), + None + ) + else: + yield from self.event_to_child(event) + + def _start_tls(self, server: context.Server): + ssl_context = SSL.Context(SSL.SSLv23_METHOD) + + if server.alpn_offers: + ssl_context.set_alpn_protos(server.alpn_offers) + + self.tls[server] = SSL.Connection(ssl_context) + + if server.sni: + if server.sni is True: + if self.context.client.sni: + server.sni = self.context.client.sni.encode("idna") + else: + server.sni = server.address[0].encode("idna") + self.tls[server].set_tlsext_host_name(server.sni) + self.tls[server].set_connect_state() + + yield from self.process(events.DataReceived(server, b"")) + + +class ClientTLSLayer(_TLSLayer): + """ + This layer establishes TLS on a single client connection. + + ┌─────┐ + │Start│ + └┬────┘ + ↓ + ┌────────────────────┐ + │Wait for ClientHello│ + └┬───────────────────┘ + │ Do we need server TLS info + │ to establish TLS with client? + │ ┌───────────────────┐ + ├─────→│Wait for Server TLS│ + │ yes └┬──────────────────┘ + │no │ + ↓ ↓ + ┌────────────────┐ + │Process messages│ + └────────────────┘ + + """ + recv_buffer: bytearray + + def __init__(self, context: context.Context): + super().__init__(context) + self.recv_buffer = bytearray() + self.child_layer = ServerTLSLayer(self.context) + + @expect(events.Start) + def state_start(self, _) -> commands.TCommandGenerator: + self.context.client.tls = True + self._handle_event = self.state_wait_for_clienthello + yield from () + + _handle_event = state_start + + @expect(events.DataReceived, events.ConnectionClosed) + def state_wait_for_clienthello(self, event: events.Event): + client = self.context.client + server = self.context.server + if isinstance(event, events.DataReceived) and event.connection == client: + self.recv_buffer.extend(event.data) + try: + client_hello = parse_client_hello(self.recv_buffer) + except ValueError as e: + raise NotImplementedError() from e # TODO + + if client_hello: + yield commands.Log(f"Client Hello: {client_hello}") + + client.sni = client_hello.sni + client.alpn_offers = client_hello.alpn_protocols + + client_tls_requires_server_connection = ( + self.context.server.tls and + self.context.options.upstream_cert and + ( + self.context.options.add_upstream_certs_to_client_chain or + client.alpn_offers or + not client.sni + ) + ) + + # What do we do with the client connection now? + if client_tls_requires_server_connection and not server.tls_established: + yield from self.start_server_tls() + self._handle_event = self.state_wait_for_server_tls + else: + yield from self.start_negotiate() + self._handle_event = self.state_process + else: + raise NotImplementedError(event) # TODO + + def state_wait_for_server_tls(self, event: events.Event): + yield from self.event_to_child(event) + # TODO: Handle case where TLS establishment fails. + # We still need a good way to signal this - one possibility would be by closing + # the connection? + if self.context.server.tls_established: + yield from self.start_negotiate() + self._handle_event = self.state_process + + def state_process(self, event: events.Event): + if isinstance(event, events.DataReceived) and event.connection == self.context.client: + if not event.connection.tls_established: + yield from self.negotiate(event) + else: + yield from self.relay(event) + else: + yield from self.event_to_child(event) + + def start_server_tls(self): + """ + We often need information from the upstream connection to establish TLS with the client. + For example, we need to check if the client does ALPN or not. + """ + if not self.context.server.connected: + err = yield commands.OpenConnection(self.context.server) + if err: + yield commands.Log( + "Cannot establish server connection, which is required to establish TLS with the client." + ) + + self.context.server.alpn_offers = [ + x for x in self.context.client.alpn_offers + if not (x.startswith(b"h2-") or x.startswith(b"spdy")) + ] + + yield from self.child_layer.handle_event(events.Start()) + + def start_negotiate(self): + if not self.child_layer: + yield from self.child_layer.handle_event(events.Start()) + + # FIXME: Do this properly + client = self.context.client + server = self.context.server + context = SSL.Context(SSL.SSLv23_METHOD) + cert, privkey, cert_chain = CertStore.from_store( + os.path.expanduser("~/.mitmproxy"), "mitmproxy" + ).get_cert(b"example.com", (b"example.com",)) + context.use_privatekey(privkey) + context.use_certificate(cert.x509) + context.set_cipher_list(tls.DEFAULT_CLIENT_CIPHERS) + + if server.alpn: + def alpn_select_callback(conn_, options): + if server.alpn in options: + return server.alpn + + context.set_alpn_select_callback(alpn_select_callback) + + self.tls[self.context.client] = SSL.Connection(context) + self.tls[self.context.client].set_accept_state() + + yield from self.state_process(events.DataReceived( + client, bytes(self.recv_buffer) + )) + self.recv_buffer = bytearray() diff --git a/mitmproxy/proxy2/server.py b/mitmproxy/proxy2/server.py index 7806502ba..24ad570bd 100644 --- a/mitmproxy/proxy2/server.py +++ b/mitmproxy/proxy2/server.py @@ -59,8 +59,13 @@ class ConnectionHandler(metaclass=abc.ABCMeta): # self._debug("transports closed!") async def close_connection(self, connection): - io = self.transports.pop(connection, None) - self.log(f"closing {connection}", "debug") + try: + io = self.transports.pop(connection) + except KeyError: + self.log(f"already closed: {connection}", "warn") + return + else: + self.log(f"closing {connection}", "debug") try: await io.w.drain() io.w.write_eof() @@ -109,10 +114,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta): print(message) def server_event(self, event: events.Event) -> None: - self.log(f">> {event}", "debug") layer_commands = self.layer.handle_event(event) for command in layer_commands: - self.log(f"<< {command}", "debug") if isinstance(command, commands.OpenConnection): asyncio.ensure_future( self.open_connection(command) @@ -155,13 +158,20 @@ if __name__ == "__main__": loop = asyncio.get_event_loop() opts = moptions.Options() - # opts.mode = "reverse:example.com" + opts.mode = "reverse:example.com" + # test client-tls-first scenario + # opts.upstream_cert = False + + layers.ClientTLSLayer.debug = "" + layers.ServerTLSLayer.debug = " " + layers.TCPLayer.debug = " " async def handle(reader, writer): layer_stack = [ - # layers.TLSLayer, - lambda c: layers.HTTPLayer(c, HTTPMode.regular), + layers.ClientTLSLayer, + #layers.ServerTLSLayer, layers.TCPLayer, + # lambda c: layers.HTTPLayer(c, HTTPMode.transparent), ] def next_layer(nl: layer.NextLayer): diff --git a/mitmproxy/proxy2/utils.py b/mitmproxy/proxy2/utils.py index 8c024d52f..d9b5718f4 100644 --- a/mitmproxy/proxy2/utils.py +++ b/mitmproxy/proxy2/utils.py @@ -57,7 +57,8 @@ def expect(*event_types): yield from f(self, event) else: raise TypeError( - "Invalid event type: Expected {}, got {}".format(event_types, event)) + "Invalid event type at {}: Expected {}, got {}.".format(f, event_types, event) + ) return wrapper