diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index 3bb1fa18b..c059b0dd1 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -1,10 +1,20 @@ from mitmproxy import ctx from mitmproxy.net import server_spec +from mitmproxy.proxy.config import HostMatcher from mitmproxy.proxy.protocol import is_tls_record_magic from mitmproxy.proxy2 import layer, layers class NextLayer: + check_tcp: HostMatcher + + def __init__(self): + self.check_tcp = HostMatcher() + + def configure(self, updated): + if "tcp_hosts" in updated: + self.check_tcp = HostMatcher(ctx.options.tcp_hosts) + def next_layer(self, nextlayer: layer.NextLayer): top_layer = nextlayer.context.layers[-1] data_client = nextlayer.data_client() @@ -15,27 +25,46 @@ class NextLayer: client_tls = is_tls_record_magic(data_client) # 1. check for --ignore + # TODO # 2. Always insert a TLS layer as second layer, even if there's neither client nor server # tls. An addon may upgrade from http to https, in which case we need a TLS layer. if isinstance(top_layer, layers.modes.ReverseProxy): - if client_tls: - nextlayer.layer = layers.TLSLayer( - nextlayer.context, - client_tls, - server_spec.parse_with_mode(ctx.options.mode)[1].scheme == "https" - ) - else: - # FIXME: TLSLayer doesn't support non-TLS yet, so remove this here once that's in. - nextlayer.layer = layers.HTTPLayer( - nextlayer.context - ) + nextlayer.context.client.tls = client_tls + nextlayer.context.server.tls = ( + server_spec.parse_with_mode(ctx.options.mode)[1].scheme == "https" + ) + nextlayer.layer = layers.TLSLayer(nextlayer.context) + return # TODO: Other top layers - pass # 3. In Http Proxy mode and Upstream Proxy mode, the next layer is fixed. + # TODO + # 4. Check for other TLS cases (e.g. after CONNECT). + if client_tls: + nextlayer.context.client.tls = True + nextlayer.context.server.tls = True + nextlayer.layer = layers.TLSLayer(nextlayer.context) + return + # 5. Check for --tcp + if self.check_tcp(nextlayer.context.server.address): + nextlayer.layer = layers.TCPLayer(nextlayer.context) + return + # 6. Check for TLS ALPN (HTTP1/HTTP2) + if isinstance(top_layer, layers.TLSLayer): + alpn = nextlayer.context.client.alpn + if alpn == b'http/1.1': + nextlayer.layer = layers.HTTPLayer(nextlayer.context) + return + # TODO + + pass # 7. Check for raw tcp mode + # TODO + # 8. Assume HTTP1 by default + nextlayer.layer = layers.HTTPLayer(nextlayer.context) + return diff --git a/mitmproxy/proxy2/context.py b/mitmproxy/proxy2/context.py index 14620a6e9..838ce8f54 100644 --- a/mitmproxy/proxy2/context.py +++ b/mitmproxy/proxy2/context.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Union from mitmproxy.options import Options @@ -8,22 +8,28 @@ class Connection: Connections exposed to the layers only contain metadata, no socket objects. """ address: tuple - connected: bool + connected: bool = False + tls: bool = False + alpn: Optional[bytes] = None def __repr__(self): return f"{type(self).__name__}({repr(self.__dict__)})" class Client(Connection): + sni: Optional[bytes] = None + def __init__(self, address): self.address = address self.connected = True class Server(Connection): + sni: Union[bytes, bool] = True + """True: client SNI, False: no SNI, bytes: custom value""" + def __init__(self, address): self.address = address - self.connected = False class Context: diff --git a/mitmproxy/proxy2/layers/tls.py b/mitmproxy/proxy2/layers/tls.py new file mode 100644 index 000000000..8355fe6d3 --- /dev/null +++ b/mitmproxy/proxy2/layers/tls.py @@ -0,0 +1,298 @@ +import os +from enum import Enum +from typing import MutableMapping, Generator, Optional + +from OpenSSL import SSL + +from mitmproxy import exceptions +from mitmproxy.certs import CertStore +from mitmproxy.proxy.protocol import TlsClientHello +from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS +from mitmproxy.proxy2 import context +from mitmproxy.proxy2 import layer, commands, events +from mitmproxy.proxy2.utils import expect + + +class ConnectionState(Enum): + NO_TLS = 1 + WAIT_FOR_CLIENTHELLO = 2 + WAIT_FOR_SERVER_TLS = 3 + WAIT_FOR_OPENCONNECTION = 4 + NEGOTIATING = 5 + ESTABLISHED = 6 + + +class TLSLayer(layer.Layer): + """ + The TLS layer manages both client- and server-side TLS connection state. + This unfortunately is quite complex as the client handshake may depend on the server + handshake and vice versa: We need the client's SNI and ALPN to connect upstream, + and we need the server's ALPN choice to complete our client TLS handshake. + On top, we may have configurations where TLS is only added on one end, + and we also may have OpenConnection events which change the server's TLS configuration. + + + The following state machine shows the possible states for client and server connection: + + Legend: + /: NO_TLS + WCH: WAIT_FOR_CLIENTHELLO + WST: WAIT_FOR_SERVER_TLS + WOC: WAIT_FOR_OPENCONNECTION + N: NEGOTIATING + E: ESTABLISHED + + +------------+ +---+ +------------+ +---+<--+ + |Client State|--------->| / | |Server State|--------->+ / | | + +------------+ no tls +---+ +------------+ no tls +---+ | + | |server tls | | + |client tls | OpenConn(TLS)| |OpenConn(no TLS) + v v v | + +------------------------------+ +-------------------->+---+ | + | no server tls | | no client tls | N | | + | | | +->+---+-->+ + |server tls v |client tls | | | + v +---->+---+ | | | | + +---+ | | N | v | v | + |WCH| | +->+---+ +---+ | +---+ | + +---+ | | | |WCH| | | E |-->+ + | | | v +---+ | +---+ | + |ClientHello arrives | | +---+ | | | + | | | | E | |ClientHello +<----+ | + +----------------------+ | +---+ |arrives | | | + | no server info needed | v | | | + | | +------------------+ | | + |server info needed | | already connected | | + v | | or server info needed | | + +---+ | | | | + |WST|-----------------------+ |not needed | | + +---+ server tls established v (TLS)| |(no TLS) + (or errored) +---+ | | + |WOC+--------------------->+----+ + +---+ OpenConn + """ + tls: MutableMapping[context.Connection, SSL.Connection] + state: MutableMapping[context.Connection, ConnectionState] + recv_buffer: MutableMapping[context.Connection, bytearray] + client_hello: Optional[TlsClientHello] + + child_layer: layer.Layer + + def __init__(self, context: context.Context): + super().__init__(context) + self.tls = {} + self.state = {} + self.recv_buffer = { + context.client: bytearray(), + context.server: bytearray() + } + self.client_hello = None + + self.child_layer = layer.NextLayer(context) + + @expect(events.Start) + def start(self, _) -> commands.TCommandGenerator: + client = self.context.client + server = self.context.server + + if client.tls and server.tls: + self.state[client] = ConnectionState.WAIT_FOR_CLIENTHELLO + self.state[server] = ConnectionState.WAIT_FOR_CLIENTHELLO + elif client.tls: + yield from self.start_client_tls() + self.state[server] = ConnectionState.NO_TLS + elif server.tls and server.connected: + self.state[client] = ConnectionState.NO_TLS + yield from self.start_server_tls() + else: + self.state[client] = ConnectionState.NO_TLS + self.state[server] = ConnectionState.NO_TLS + + yield from self.child_layer.handle_event(events.Start()) + self._handle_event = self.process + + _handle_event = start + + def send(self, send_command: commands.SendData) -> commands.TCommandGenerator: + if self.state[send_command.connection] == ConnectionState.NO_TLS: + yield send_command + else: + self.tls[send_command.connection].sendall(send_command.data) + yield from self.tls_interact(send_command.connection) + + def event_to_child(self, event: events.Event) -> commands.TCommandGenerator: + for command in self.child_layer.handle_event(event): + if isinstance(command, commands.SendData): + yield commands.Log(f"Plain{command}") + yield from self.send(command) + elif isinstance(command, commands.OpenConnection): + raise NotImplementedError() + else: + yield command + + def recv(self, recv_event: events.DataReceived) -> Generator[commands.Command, None, bytes]: + if self.state[recv_event.connection] == ConnectionState.NO_TLS: + return recv_event.data + else: + if recv_event.data: + self.tls[recv_event.connection].bio_write(recv_event.data) + yield from self.tls_interact(recv_event.connection) + + recvd = bytearray() + while True: + try: + recvd.extend(self.tls[recv_event.connection].recv(65535)) + except (SSL.WantReadError, SSL.ZeroReturnError): + return bytes(recvd) + + def parse_client_hello(self): + # Check if ClientHello is complete + # FIXME: temporary mock + class Rfile: + def __init__(self, data): + self.data = data + + def peek(self, n): + return self.data[:n] + + class CCon: + def __init__(self, data): + self.rfile = Rfile(data) + + try: + self.client_hello = TlsClientHello.from_client_conn( + CCon( + self.recv_buffer[self.context.client] + ) + ) + except exceptions.TlsProtocolException: + return False + else: + return True + + def process(self, event: events.Event): + if isinstance(event, events.DataReceived): + state = self.state[event.connection] + + if state is ConnectionState.WAIT_FOR_CLIENTHELLO: + yield from self.process_wait_for_clienthello(event) + elif state is ConnectionState.WAIT_FOR_SERVER_TLS: + self.recv_buffer[self.context.client].extend(event.data) + elif state is ConnectionState.NEGOTIATING: + yield from self.process_negotiate(event) + elif state is ConnectionState.NO_TLS: + yield from self.process_relay(event) + elif state is ConnectionState.ESTABLISHED: + yield from self.process_relay(event) + else: + raise RuntimeError("Unexpected state") + else: + yield from self.event_to_child(event) + + def process_wait_for_clienthello(self, event: events.DataReceived): + client = self.context.client + server = self.context.server + # We are not ready to process this yet. + self.recv_buffer[event.connection].extend(event.data) + + if event.connection == client and self.parse_client_hello(): + self._debug("SNI", self.client_hello.sni) + if self.context.server.sni is True: + self.context.server.sni = self.client_hello.sni.encode("idna") + + client_tls_requires_server_connection = ( + self.context.server.tls and + self.context.options.upstream_cert and + ( + self.context.options.add_upstream_certs_to_client_chain or + self.client_hello.alpn_protocols or + not self.client_hello.sni + ) + ) + + if client_tls_requires_server_connection and not self.context.server.connected: + yield commands.OpenConnection(self.context.server) + + if not self.context.server.connected: + # We are only in the WAIT_FOR_CLIENTHELLO branch if we have two TLS conns. + assert self.context.server.tls + self.state[server] = ConnectionState.WAIT_FOR_OPENCONNECTION + else: + yield from self.start_server_tls() + if client_tls_requires_server_connection: + self.state[client] = ConnectionState.WAIT_FOR_SERVER_TLS + else: + yield from self.start_client_tls() + + def process_negotiate(self, event: events.DataReceived): + # bio_write errors for b"", so we need to check first if we actually received something. + if event.data: + self.tls[event.connection].bio_write(event.data) + try: + self.tls[event.connection].do_handshake() + except SSL.WantReadError: + yield from self.tls_interact(event.connection) + else: + self.state[event.connection] = ConnectionState.ESTABLISHED + event.connection.sni = self.tls[event.connection].get_servername() + event.connection.alpn = self.tls[event.connection].get_alpn_proto_negotiated() + + # there might already be data in the OpenSSL BIO, so we need to trigger its processing. + yield from self.process(events.DataReceived(event.connection, b"")) + + if self.state[self.context.client] == ConnectionState.WAIT_FOR_SERVER_TLS: + assert event.connection == self.context.server + yield from self.start_client_tls() + + def process_relay(self, event: events.DataReceived): + plaintext = yield from self.recv(event) + if plaintext: + evt = events.DataReceived(event.connection, plaintext) + yield commands.Log(f"Plain{evt}") + yield from self.event_to_child(evt) + + def start_server_tls(self): + server = self.context.server + + ssl_context = SSL.Context(SSL.SSLv23_METHOD) + self.tls[server] = SSL.Connection(ssl_context) + + if server.sni: + self.tls[server].set_tlsext_host_name(server.sni) + # FIXME: Handle ALPN + self.tls[server].set_connect_state() + + self.state[server] = ConnectionState.NEGOTIATING + yield from self.process(events.DataReceived( + server, bytes(self.recv_buffer[server]) + )) + self.recv_buffer[server] = bytearray() + + def start_client_tls(self): + # FIXME + client = self.context.client + context = SSL.Context(SSL.SSLv23_METHOD) + cert, privkey, cert_chain = CertStore.from_store( + os.path.expanduser("~/.mitmproxy"), "mitmproxy" + ).get_cert(b"example.com", (b"example.com",)) + context.use_privatekey(privkey) + context.use_certificate(cert.x509) + context.set_cipher_list(DEFAULT_CLIENT_CIPHERS) + self.tls[client] = SSL.Connection(context) + self.tls[client].set_accept_state() + + self.state[client] = ConnectionState.NEGOTIATING + yield from self.process(events.DataReceived( + client, bytes(self.recv_buffer[client]) + )) + self.recv_buffer[client] = bytearray() + + def tls_interact(self, conn: context.Connection): + while True: + try: + data = self.tls[conn].bio_read(65535) + except SSL.WantReadError: + # Okay, nothing more waiting to be sent. + return + else: + yield commands.SendData(conn, data) diff --git a/mitmproxy/proxy2/layers/tls_old.py b/mitmproxy/proxy2/layers/tls_old.py deleted file mode 100644 index baf6db2c4..000000000 --- a/mitmproxy/proxy2/layers/tls_old.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -TLS man-in-the-middle layer. -""" -# We may want to split this up into client (only once) and server (for every server) layer. -import os -from typing import MutableMapping -from warnings import warn - -from OpenSSL import SSL - -from mitmproxy.certs import CertStore -from mitmproxy.proxy.protocol.tls import DEFAULT_CLIENT_CIPHERS -from mitmproxy.proxy2 import events, commands, layer -from mitmproxy.proxy2.context import Context, Connection -from mitmproxy.proxy2.utils import expect - - -class TLSLayer(layer.Layer): - client_tls: bool # FIXME: not yet used. - server_tls: bool - child_layer: layer.Layer = None - tls: MutableMapping[Connection, SSL.Connection] - - def __init__(self, context: Context, client_tls: bool, server_tls: bool): - super().__init__(context) - self.state = self.start - self.client_tls = client_tls - self.server_tls = server_tls - self.tls = {} - - def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: - yield from self.state(event) - - @expect(events.Start) - def start(self, _) -> commands.TCommandGenerator: - yield from self.start_client_tls() - if not self.context.server.connected: - # TODO: This should be lazy. - yield commands.OpenConnection(self.context.server) - yield from self.start_server_tls() - self.state = self.establish_tls - - def start_client_tls(self): - conn = self.context.client - context = SSL.Context(SSL.SSLv23_METHOD) - cert, privkey, cert_chain = CertStore.from_store( - os.path.expanduser("~/.mitmproxy"), "mitmproxy" - ).get_cert(b"example.com", (b"example.com",)) - context.use_privatekey(privkey) - context.use_certificate(cert.x509) - context.set_cipher_list(DEFAULT_CLIENT_CIPHERS) - self.tls[conn] = SSL.Connection(context) - self.tls[conn].set_accept_state() - try: - self.tls[conn].do_handshake() - except SSL.WantReadError: - pass - yield from self.tls_interact(conn) - - def start_server_tls(self): - conn = self.context.server - self.tls[conn] = SSL.Connection(SSL.Context(SSL.SSLv23_METHOD)) - self.tls[conn].set_connect_state() - try: - self.tls[conn].do_handshake() - except SSL.WantReadError: - pass - yield from self.tls_interact(conn) - - def tls_interact(self, conn: Connection): - while True: - try: - data = self.tls[conn].bio_read(4096) - except SSL.WantReadError: - # Okay, nothing more waiting to be sent. - return - else: - yield commands.SendData(conn, data) - - @expect(events.ConnectionClosed, events.DataReceived) - def establish_tls(self, event: events.Event) -> commands.TCommandGenerator: - if isinstance(event, events.DataReceived): - self.tls[event.connection].bio_write(event.data) - try: - self.tls[event.connection].do_handshake() - except SSL.WantReadError: - pass - yield from self.tls_interact(event.connection) - - both_handshakes_done = ( - self.tls[self.context.client].get_peer_finished() and - self.context.server in self.tls and self.tls[ - self.context.server].get_peer_finished() - ) - - if both_handshakes_done: - print("both handshakes done") - self.child_layer = layer.NextLayer(self.context) - yield from self.child_layer.handle_event(events.Start()) - self.state = self.relay_messages - yield from self.state(events.DataReceived(self.context.server, b"")) - yield from self.state(events.DataReceived(self.context.client, b"")) - - elif isinstance(event, events.ConnectionClosed): - warn("unimplemented: tls.establish_tls:close") - - @expect(events.ConnectionClosed, events.DataReceived) - def relay_messages(self, event: events.Event) -> commands.TCommandGenerator: - if isinstance(event, events.DataReceived): - if event.data: - self.tls[event.connection].bio_write(event.data) - yield from self.tls_interact(event.connection) - - while True: - try: - plaintext = self.tls[event.connection].recv(4096) - except (SSL.WantReadError, SSL.ZeroReturnError): - return - - event_for_child = events.DataReceived(self.context.server, plaintext) - - for event_from_child in self.child_layer.handle_event(event_for_child): - if isinstance(event_from_child, commands.SendData): - self.tls[event_from_child.connection].sendall(event_from_child.data) - yield from self.tls_interact(event_from_child.connection) - else: - yield event_from_child - elif isinstance(event, events.ConnectionClosed): - warn("unimplemented: tls.relay_messages:close")