diff --git a/mitmproxy/proxy2/layers/http/http.py b/mitmproxy/proxy2/layers/http/http.py index c16d247af..b0eadcdd7 100644 --- a/mitmproxy/proxy2/layers/http/http.py +++ b/mitmproxy/proxy2/layers/http/http.py @@ -601,8 +601,11 @@ class HTTPLayer(Layer): self.connections[event.command.connection] = stream self.event_to_child(stream, event) elif isinstance(event, events.ConnectionEvent): - handler = self.connections[event.connection] - self.event_to_child(handler, event) + if event.connection == self.context.server and self.context.server not in self.connections: + pass + else: + handler = self.connections[event.connection] + self.event_to_child(handler, event) else: raise ValueError(f"Unexpected event: {event}") @@ -624,6 +627,15 @@ class HTTPLayer(Layer): else: raise ValueError(f"Unexpected event: {event}") + def make_stream(self) -> HttpStream: + ctx = self.context.fork() + + stream = HttpStream(ctx) + if self.debug: + stream.debug = self.debug + " " + self.event_to_child(stream, events.Start()) + return stream + def get_connection(self, event: GetHttpConnection): # Do we already have a connection we can re-use? for connection, handler in self.connections.items(): @@ -646,6 +658,7 @@ class HTTPLayer(Layer): return # Can we reuse context.server? can_reuse_context_connection = ( + self.context.server not in self.connections and self.context.server.connected and self.context.server.tls == event.tls ) @@ -661,15 +674,6 @@ class HTTPLayer(Layer): open_command.blocking = object() yield open_command - def make_stream(self) -> HttpStream: - ctx = self.context.fork() - - stream = HttpStream(ctx) - if self.debug: - stream.debug = self.debug + " " - self.event_to_child(stream, events.Start()) - return stream - def make_http_connection(self, connection: Server) -> None: if connection.tls and not connection.tls_established: new_command = EstablishServerTLS(connection) diff --git a/mitmproxy/proxy2/layers/tls.py b/mitmproxy/proxy2/layers/tls.py index 17fa55a83..e3ff31e46 100644 --- a/mitmproxy/proxy2/layers/tls.py +++ b/mitmproxy/proxy2/layers/tls.py @@ -1,6 +1,6 @@ import os import struct -from typing import Any, Generator, Iterator, Optional +from typing import Any, Dict, Generator, Iterator, Optional, Tuple from OpenSSL import SSL @@ -100,131 +100,154 @@ class EstablishServerTLSReply(events.CommandReply): class _TLSLayer(layer.Layer): - conn: Optional[context.Connection] = None - tls_conn: Optional[SSL.Connection] = None + tls: Dict[context.Connection, SSL.Connection] child_layer: layer.Layer + def __init__(self, context: context.Context): + super().__init__(context) + self.tls = {} + def __repr__(self): - if self.conn is None: + if not self.tls: state = "inactive" - elif self.conn.tls_established: - state = f"passthrough {self.conn.sni}, {self.conn.alpn}" else: - state = f"negotiating {self.conn.sni}, {self.conn.alpn}" + conn_states = [] + for conn in self.tls: + if conn.tls_established: + conn_states.append(f"passthrough {conn.sni} {conn.alpn}") + else: + conn_states.append(f"negotiating {conn.sni} {conn.alpn}") + state = ", ".join(conn_states) return f"{type(self).__name__}({state})" - def tls_interact(self): + def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator: while True: try: - data = self.tls_conn.bio_read(65535) + data = self.tls[conn].bio_read(65535) except SSL.WantReadError: # Okay, nothing more waiting to be sent. return else: - yield commands.SendData(self.conn, data) + yield commands.SendData(conn, data) - def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]: + def negotiate(self, conn: context.Connection, data: bytes) -> Generator[ + commands.Command, Any, Tuple[bool, Optional[str]]]: # bio_write errors for b"", so we need to check first if we actually received something. if data: - self.tls_conn.bio_write(data) + self.tls[conn].bio_write(data) try: - self.tls_conn.do_handshake() + self.tls[conn].do_handshake() except SSL.WantReadError: - yield from self.tls_interact() - return False - except SSL.ZeroReturnError: - raise # TODO: Figure out what to do when handshake fails. + yield from self.tls_interact(conn) + return False, None + except SSL.ZeroReturnError as e: + return False, repr(e) else: - self.conn.tls_established = True - self.conn.alpn = self.tls_conn.get_alpn_proto_negotiated() - yield commands.Log(f"TLS established: {self.conn}") - yield from self.receive(b"") + conn.tls_established = True + conn.alpn = self.tls[conn].get_alpn_proto_negotiated() + yield commands.Log(f"TLS established: {conn}") + yield from self.receive(conn, b"") # TODO: Set all other connection attributes here # there might already be data in the OpenSSL BIO, so we need to trigger its processing. - return True + return True, None - def receive(self, data: bytes): + def receive(self, conn: context.Connection, data: bytes): if data: - self.tls_conn.bio_write(data) - yield from self.tls_interact() + self.tls[conn].bio_write(data) + yield from self.tls_interact(conn) plaintext = bytearray() + close = False while True: try: - plaintext.extend(self.tls_conn.recv(65535)) - except (SSL.WantReadError, SSL.ZeroReturnError): + plaintext.extend(self.tls[conn].recv(65535)) + except SSL.WantReadError: + break + except SSL.ZeroReturnError: + close = True break if plaintext: - evt = events.DataReceived(self.conn, bytes(plaintext)) - yield from self.event_to_child(evt) + yield from self.event_to_child( + events.DataReceived(conn, bytes(plaintext)) + ) + if close: + conn.state &= ~context.ConnectionState.CAN_READ + yield commands.Log(f"TLS close_notify {conn=}") + yield from self.event_to_child( + events.ConnectionClosed(conn) + ) def event_to_child(self, event: events.Event) -> commands.TCommandGenerator: for command in self.child_layer.handle_event(event): - if isinstance(command, commands.SendData) and command.connection == self.conn: - self.tls_conn.sendall(command.data) - yield from self.tls_interact() + if isinstance(command, commands.SendData) and command.connection in self.tls: + self.tls[command.connection].sendall(command.data) + yield from self.tls_interact(command.connection) else: yield command def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: - if isinstance(event, events.DataReceived) and event.connection == self.conn: - if not self.conn.tls_established: - yield from self.negotiate(event.data) + if isinstance(event, events.DataReceived) and event.connection in self.tls: + if not event.connection.tls_established: + yield from self.negotiate(event.connection, event.data) else: - yield from self.receive(event.data) + yield from self.receive(event.connection, event.data) + elif ( + isinstance(event, events.ConnectionClosed) and + event.connection in self.tls and + self.tls[event.connection].get_shutdown() & SSL.RECEIVED_SHUTDOWN + ): + pass # We have already dispatched a ConnectionClosed to the child layer. else: yield from self.event_to_child(event) class ServerTLSLayer(_TLSLayer): """ - This layer manages TLS for a single server connection. + This layer manages TLS for potentially multiple server connections. """ - command_to_reply_to: Optional[EstablishServerTLS] = None + command_to_reply_to: Dict[context.Connection, EstablishServerTLS] def __init__(self, context: context.Context): super().__init__(context) + self.command_to_reply_to = {} self.child_layer = layer.NextLayer(self.context) - def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]: - done = yield from super().negotiate(data) - if done: - assert self.command_to_reply_to - yield from self.event_to_child(EstablishServerTLSReply(self.command_to_reply_to, None)) - self.command_to_reply_to = None - return done + def negotiate(self, conn: context.Connection, data: bytes) -> Generator[ + commands.Command, Any, Tuple[bool, Optional[str]]]: + done, err = yield from super().negotiate(conn, data) + if done or err: + cmd = self.command_to_reply_to.pop(conn) + yield from self.event_to_child(EstablishServerTLSReply(cmd, err)) + return done, err def event_to_child(self, event: events.Event) -> commands.TCommandGenerator: for command in super().event_to_child(event): if isinstance(command, EstablishServerTLS): - assert isinstance(command.connection, context.Server) - assert not self.command_to_reply_to - self.command_to_reply_to = command + self.command_to_reply_to[command.connection] = command yield from self.start_server_tls(command.connection) else: yield command - def start_server_tls(self, server: context.Server): + def start_server_tls(self, conn: context.Server): + assert conn not in self.tls + assert conn.connected + ssl_context = SSL.Context(SSL.SSLv23_METHOD) - if server.alpn_offers: - ssl_context.set_alpn_protos(server.alpn_offers) + if conn.alpn_offers: + ssl_context.set_alpn_protos(conn.alpn_offers) + self.tls[conn] = SSL.Connection(ssl_context) - assert not self.conn or not self.conn.connected - assert server.connected - self.conn = server - self.tls_conn = SSL.Connection(ssl_context) - - if server.sni: - if server.sni is True: + if conn.sni: + if conn.sni is True: if self.context.client.sni: - server.sni = self.context.client.sni + conn.sni = self.context.client.sni else: - server.sni = server.address[0] - self.tls_conn.set_tlsext_host_name(server.sni) - self.tls_conn.set_connect_state() + conn.sni = conn.address[0].encode() + self.tls[conn].set_tlsext_host_name(conn.sni) + self.tls[conn].set_connect_state() - yield from self.negotiate(b"") + yield from self.negotiate(conn, b"") class ClientTLSLayer(_TLSLayer): @@ -249,7 +272,6 @@ class ClientTLSLayer(_TLSLayer): def __init__(self, context: context.Context): assert isinstance(context.layers[-1], ServerTLSLayer) super().__init__(context) - self.conn = context.client self.recv_buffer = bytearray() self.child_layer = layer.NextLayer(self.context) @@ -269,7 +291,7 @@ class ClientTLSLayer(_TLSLayer): try: client_hello = parse_client_hello(self.recv_buffer) except ValueError as e: - raise NotImplementedError() from e # TODO + raise NotImplementedError from e # TODO if client_hello: yield commands.Log(f"Client Hello: {client_hello}") @@ -292,7 +314,8 @@ class ClientTLSLayer(_TLSLayer): if client_tls_requires_server_connection and not self.context.server.tls_established: err = yield from self.start_server_tls() if err: - raise NotImplementedError + yield commands.Log("Unable to establish TLS connection with server. " + "Trying to establish TLS with client anyway.") yield from self.start_client_tls() self._handle_event = super()._handle_event @@ -312,7 +335,7 @@ class ClientTLSLayer(_TLSLayer): err = yield commands.OpenConnection(server) if err: yield commands.Log( - "Cannot establish server connection, which is required to establish TLS with the client." + f"Cannot establish server connection: {err}" ) return err @@ -324,11 +347,11 @@ class ClientTLSLayer(_TLSLayer): err = yield EstablishServerTLS(server) if err: yield commands.Log( - "Cannot establish TLS with server, which is required to establish TLS with the client." + f"Cannot establish TLS with server: {err}" ) return err - def start_client_tls(self): + def start_client_tls(self) -> commands.TCommandGenerator: # FIXME: Do this properly. Also adjust error message in negotiate() client = self.context.client server = self.context.server @@ -358,17 +381,15 @@ class ClientTLSLayer(_TLSLayer): context.set_alpn_select_callback(alpn_select_callback) - self.tls_conn = SSL.Connection(context) - self.tls_conn.set_accept_state() + self.tls[client] = SSL.Connection(context) + self.tls[client].set_accept_state() - yield from self.negotiate(bytes(self.recv_buffer)) - self.recv_buffer = None + yield from self.negotiate(client, bytes(self.recv_buffer)) + self.recv_buffer.clear() - def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]: - try: - done = yield from super().negotiate(data) - return done - except SSL.ZeroReturnError: + def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]: + done, err = yield from super().negotiate(conn, data) + if err: yield commands.Log( f"Client TLS Handshake failed. " f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).", @@ -376,3 +397,4 @@ class ClientTLSLayer(_TLSLayer): # TODO: Also use other sources than SNI ) yield commands.CloseConnection(self.context.client) + return done diff --git a/mitmproxy/proxy2/server.py b/mitmproxy/proxy2/server.py index a6636de15..23aa97285 100644 --- a/mitmproxy/proxy2/server.py +++ b/mitmproxy/proxy2/server.py @@ -191,7 +191,7 @@ if __name__ == "__main__": opts = moptions.Options() opts.add_option( - "connection_strategy", str, "lazy", + "connection_strategy", str, "eager", "Determine when server connections should be established.", choices=("eager", "lazy") ) @@ -200,6 +200,7 @@ if __name__ == "__main__": async def handle(reader, writer): layer_stack = [ + lambda ctx: layers.ServerTLSLayer(ctx), lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular), lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx), lambda ctx: layers.ClientTLSLayer(ctx),