mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] tls: handle multiple servers
This commit is contained in:
parent
ffb3782618
commit
549e41ee40
@ -601,8 +601,11 @@ class HTTPLayer(Layer):
|
||||
self.connections[event.command.connection] = stream
|
||||
self.event_to_child(stream, event)
|
||||
elif isinstance(event, events.ConnectionEvent):
|
||||
handler = self.connections[event.connection]
|
||||
self.event_to_child(handler, event)
|
||||
if event.connection == self.context.server and self.context.server not in self.connections:
|
||||
pass
|
||||
else:
|
||||
handler = self.connections[event.connection]
|
||||
self.event_to_child(handler, event)
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
|
||||
@ -624,6 +627,15 @@ class HTTPLayer(Layer):
|
||||
else:
|
||||
raise ValueError(f"Unexpected event: {event}")
|
||||
|
||||
def make_stream(self) -> HttpStream:
|
||||
ctx = self.context.fork()
|
||||
|
||||
stream = HttpStream(ctx)
|
||||
if self.debug:
|
||||
stream.debug = self.debug + " "
|
||||
self.event_to_child(stream, events.Start())
|
||||
return stream
|
||||
|
||||
def get_connection(self, event: GetHttpConnection):
|
||||
# Do we already have a connection we can re-use?
|
||||
for connection, handler in self.connections.items():
|
||||
@ -646,6 +658,7 @@ class HTTPLayer(Layer):
|
||||
return
|
||||
# Can we reuse context.server?
|
||||
can_reuse_context_connection = (
|
||||
self.context.server not in self.connections and
|
||||
self.context.server.connected and
|
||||
self.context.server.tls == event.tls
|
||||
)
|
||||
@ -661,15 +674,6 @@ class HTTPLayer(Layer):
|
||||
open_command.blocking = object()
|
||||
yield open_command
|
||||
|
||||
def make_stream(self) -> HttpStream:
|
||||
ctx = self.context.fork()
|
||||
|
||||
stream = HttpStream(ctx)
|
||||
if self.debug:
|
||||
stream.debug = self.debug + " "
|
||||
self.event_to_child(stream, events.Start())
|
||||
return stream
|
||||
|
||||
def make_http_connection(self, connection: Server) -> None:
|
||||
if connection.tls and not connection.tls_established:
|
||||
new_command = EstablishServerTLS(connection)
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import struct
|
||||
from typing import Any, Generator, Iterator, Optional
|
||||
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
@ -100,131 +100,154 @@ class EstablishServerTLSReply(events.CommandReply):
|
||||
|
||||
|
||||
class _TLSLayer(layer.Layer):
|
||||
conn: Optional[context.Connection] = None
|
||||
tls_conn: Optional[SSL.Connection] = None
|
||||
tls: Dict[context.Connection, SSL.Connection]
|
||||
child_layer: layer.Layer
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context)
|
||||
self.tls = {}
|
||||
|
||||
def __repr__(self):
|
||||
if self.conn is None:
|
||||
if not self.tls:
|
||||
state = "inactive"
|
||||
elif self.conn.tls_established:
|
||||
state = f"passthrough {self.conn.sni}, {self.conn.alpn}"
|
||||
else:
|
||||
state = f"negotiating {self.conn.sni}, {self.conn.alpn}"
|
||||
conn_states = []
|
||||
for conn in self.tls:
|
||||
if conn.tls_established:
|
||||
conn_states.append(f"passthrough {conn.sni} {conn.alpn}")
|
||||
else:
|
||||
conn_states.append(f"negotiating {conn.sni} {conn.alpn}")
|
||||
state = ", ".join(conn_states)
|
||||
return f"{type(self).__name__}({state})"
|
||||
|
||||
def tls_interact(self):
|
||||
def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator:
|
||||
while True:
|
||||
try:
|
||||
data = self.tls_conn.bio_read(65535)
|
||||
data = self.tls[conn].bio_read(65535)
|
||||
except SSL.WantReadError:
|
||||
# Okay, nothing more waiting to be sent.
|
||||
return
|
||||
else:
|
||||
yield commands.SendData(self.conn, data)
|
||||
yield commands.SendData(conn, data)
|
||||
|
||||
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]:
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
|
||||
commands.Command, Any, Tuple[bool, Optional[str]]]:
|
||||
# bio_write errors for b"", so we need to check first if we actually received something.
|
||||
if data:
|
||||
self.tls_conn.bio_write(data)
|
||||
self.tls[conn].bio_write(data)
|
||||
try:
|
||||
self.tls_conn.do_handshake()
|
||||
self.tls[conn].do_handshake()
|
||||
except SSL.WantReadError:
|
||||
yield from self.tls_interact()
|
||||
return False
|
||||
except SSL.ZeroReturnError:
|
||||
raise # TODO: Figure out what to do when handshake fails.
|
||||
yield from self.tls_interact(conn)
|
||||
return False, None
|
||||
except SSL.ZeroReturnError as e:
|
||||
return False, repr(e)
|
||||
else:
|
||||
self.conn.tls_established = True
|
||||
self.conn.alpn = self.tls_conn.get_alpn_proto_negotiated()
|
||||
yield commands.Log(f"TLS established: {self.conn}")
|
||||
yield from self.receive(b"")
|
||||
conn.tls_established = True
|
||||
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
||||
yield commands.Log(f"TLS established: {conn}")
|
||||
yield from self.receive(conn, b"")
|
||||
# TODO: Set all other connection attributes here
|
||||
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
|
||||
return True
|
||||
return True, None
|
||||
|
||||
def receive(self, data: bytes):
|
||||
def receive(self, conn: context.Connection, data: bytes):
|
||||
if data:
|
||||
self.tls_conn.bio_write(data)
|
||||
yield from self.tls_interact()
|
||||
self.tls[conn].bio_write(data)
|
||||
yield from self.tls_interact(conn)
|
||||
|
||||
plaintext = bytearray()
|
||||
close = False
|
||||
while True:
|
||||
try:
|
||||
plaintext.extend(self.tls_conn.recv(65535))
|
||||
except (SSL.WantReadError, SSL.ZeroReturnError):
|
||||
plaintext.extend(self.tls[conn].recv(65535))
|
||||
except SSL.WantReadError:
|
||||
break
|
||||
except SSL.ZeroReturnError:
|
||||
close = True
|
||||
break
|
||||
|
||||
if plaintext:
|
||||
evt = events.DataReceived(self.conn, bytes(plaintext))
|
||||
yield from self.event_to_child(evt)
|
||||
yield from self.event_to_child(
|
||||
events.DataReceived(conn, bytes(plaintext))
|
||||
)
|
||||
if close:
|
||||
conn.state &= ~context.ConnectionState.CAN_READ
|
||||
yield commands.Log(f"TLS close_notify {conn=}")
|
||||
yield from self.event_to_child(
|
||||
events.ConnectionClosed(conn)
|
||||
)
|
||||
|
||||
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
for command in self.child_layer.handle_event(event):
|
||||
if isinstance(command, commands.SendData) and command.connection == self.conn:
|
||||
self.tls_conn.sendall(command.data)
|
||||
yield from self.tls_interact()
|
||||
if isinstance(command, commands.SendData) and command.connection in self.tls:
|
||||
self.tls[command.connection].sendall(command.data)
|
||||
yield from self.tls_interact(command.connection)
|
||||
else:
|
||||
yield command
|
||||
|
||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
if isinstance(event, events.DataReceived) and event.connection == self.conn:
|
||||
if not self.conn.tls_established:
|
||||
yield from self.negotiate(event.data)
|
||||
if isinstance(event, events.DataReceived) and event.connection in self.tls:
|
||||
if not event.connection.tls_established:
|
||||
yield from self.negotiate(event.connection, event.data)
|
||||
else:
|
||||
yield from self.receive(event.data)
|
||||
yield from self.receive(event.connection, event.data)
|
||||
elif (
|
||||
isinstance(event, events.ConnectionClosed) and
|
||||
event.connection in self.tls and
|
||||
self.tls[event.connection].get_shutdown() & SSL.RECEIVED_SHUTDOWN
|
||||
):
|
||||
pass # We have already dispatched a ConnectionClosed to the child layer.
|
||||
else:
|
||||
yield from self.event_to_child(event)
|
||||
|
||||
|
||||
class ServerTLSLayer(_TLSLayer):
|
||||
"""
|
||||
This layer manages TLS for a single server connection.
|
||||
This layer manages TLS for potentially multiple server connections.
|
||||
"""
|
||||
command_to_reply_to: Optional[EstablishServerTLS] = None
|
||||
command_to_reply_to: Dict[context.Connection, EstablishServerTLS]
|
||||
|
||||
def __init__(self, context: context.Context):
|
||||
super().__init__(context)
|
||||
self.command_to_reply_to = {}
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]:
|
||||
done = yield from super().negotiate(data)
|
||||
if done:
|
||||
assert self.command_to_reply_to
|
||||
yield from self.event_to_child(EstablishServerTLSReply(self.command_to_reply_to, None))
|
||||
self.command_to_reply_to = None
|
||||
return done
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
|
||||
commands.Command, Any, Tuple[bool, Optional[str]]]:
|
||||
done, err = yield from super().negotiate(conn, data)
|
||||
if done or err:
|
||||
cmd = self.command_to_reply_to.pop(conn)
|
||||
yield from self.event_to_child(EstablishServerTLSReply(cmd, err))
|
||||
return done, err
|
||||
|
||||
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
||||
for command in super().event_to_child(event):
|
||||
if isinstance(command, EstablishServerTLS):
|
||||
assert isinstance(command.connection, context.Server)
|
||||
assert not self.command_to_reply_to
|
||||
self.command_to_reply_to = command
|
||||
self.command_to_reply_to[command.connection] = command
|
||||
yield from self.start_server_tls(command.connection)
|
||||
else:
|
||||
yield command
|
||||
|
||||
def start_server_tls(self, server: context.Server):
|
||||
def start_server_tls(self, conn: context.Server):
|
||||
assert conn not in self.tls
|
||||
assert conn.connected
|
||||
|
||||
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
if server.alpn_offers:
|
||||
ssl_context.set_alpn_protos(server.alpn_offers)
|
||||
if conn.alpn_offers:
|
||||
ssl_context.set_alpn_protos(conn.alpn_offers)
|
||||
self.tls[conn] = SSL.Connection(ssl_context)
|
||||
|
||||
assert not self.conn or not self.conn.connected
|
||||
assert server.connected
|
||||
self.conn = server
|
||||
self.tls_conn = SSL.Connection(ssl_context)
|
||||
|
||||
if server.sni:
|
||||
if server.sni is True:
|
||||
if conn.sni:
|
||||
if conn.sni is True:
|
||||
if self.context.client.sni:
|
||||
server.sni = self.context.client.sni
|
||||
conn.sni = self.context.client.sni
|
||||
else:
|
||||
server.sni = server.address[0]
|
||||
self.tls_conn.set_tlsext_host_name(server.sni)
|
||||
self.tls_conn.set_connect_state()
|
||||
conn.sni = conn.address[0].encode()
|
||||
self.tls[conn].set_tlsext_host_name(conn.sni)
|
||||
self.tls[conn].set_connect_state()
|
||||
|
||||
yield from self.negotiate(b"")
|
||||
yield from self.negotiate(conn, b"")
|
||||
|
||||
|
||||
class ClientTLSLayer(_TLSLayer):
|
||||
@ -249,7 +272,6 @@ class ClientTLSLayer(_TLSLayer):
|
||||
def __init__(self, context: context.Context):
|
||||
assert isinstance(context.layers[-1], ServerTLSLayer)
|
||||
super().__init__(context)
|
||||
self.conn = context.client
|
||||
self.recv_buffer = bytearray()
|
||||
self.child_layer = layer.NextLayer(self.context)
|
||||
|
||||
@ -269,7 +291,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
try:
|
||||
client_hello = parse_client_hello(self.recv_buffer)
|
||||
except ValueError as e:
|
||||
raise NotImplementedError() from e # TODO
|
||||
raise NotImplementedError from e # TODO
|
||||
|
||||
if client_hello:
|
||||
yield commands.Log(f"Client Hello: {client_hello}")
|
||||
@ -292,7 +314,8 @@ class ClientTLSLayer(_TLSLayer):
|
||||
if client_tls_requires_server_connection and not self.context.server.tls_established:
|
||||
err = yield from self.start_server_tls()
|
||||
if err:
|
||||
raise NotImplementedError
|
||||
yield commands.Log("Unable to establish TLS connection with server. "
|
||||
"Trying to establish TLS with client anyway.")
|
||||
|
||||
yield from self.start_client_tls()
|
||||
self._handle_event = super()._handle_event
|
||||
@ -312,7 +335,7 @@ class ClientTLSLayer(_TLSLayer):
|
||||
err = yield commands.OpenConnection(server)
|
||||
if err:
|
||||
yield commands.Log(
|
||||
"Cannot establish server connection, which is required to establish TLS with the client."
|
||||
f"Cannot establish server connection: {err}"
|
||||
)
|
||||
return err
|
||||
|
||||
@ -324,11 +347,11 @@ class ClientTLSLayer(_TLSLayer):
|
||||
err = yield EstablishServerTLS(server)
|
||||
if err:
|
||||
yield commands.Log(
|
||||
"Cannot establish TLS with server, which is required to establish TLS with the client."
|
||||
f"Cannot establish TLS with server: {err}"
|
||||
)
|
||||
return err
|
||||
|
||||
def start_client_tls(self):
|
||||
def start_client_tls(self) -> commands.TCommandGenerator:
|
||||
# FIXME: Do this properly. Also adjust error message in negotiate()
|
||||
client = self.context.client
|
||||
server = self.context.server
|
||||
@ -358,17 +381,15 @@ class ClientTLSLayer(_TLSLayer):
|
||||
|
||||
context.set_alpn_select_callback(alpn_select_callback)
|
||||
|
||||
self.tls_conn = SSL.Connection(context)
|
||||
self.tls_conn.set_accept_state()
|
||||
self.tls[client] = SSL.Connection(context)
|
||||
self.tls[client].set_accept_state()
|
||||
|
||||
yield from self.negotiate(bytes(self.recv_buffer))
|
||||
self.recv_buffer = None
|
||||
yield from self.negotiate(client, bytes(self.recv_buffer))
|
||||
self.recv_buffer.clear()
|
||||
|
||||
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]:
|
||||
try:
|
||||
done = yield from super().negotiate(data)
|
||||
return done
|
||||
except SSL.ZeroReturnError:
|
||||
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
|
||||
done, err = yield from super().negotiate(conn, data)
|
||||
if err:
|
||||
yield commands.Log(
|
||||
f"Client TLS Handshake failed. "
|
||||
f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).",
|
||||
@ -376,3 +397,4 @@ class ClientTLSLayer(_TLSLayer):
|
||||
# TODO: Also use other sources than SNI
|
||||
)
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
return done
|
||||
|
@ -191,7 +191,7 @@ if __name__ == "__main__":
|
||||
|
||||
opts = moptions.Options()
|
||||
opts.add_option(
|
||||
"connection_strategy", str, "lazy",
|
||||
"connection_strategy", str, "eager",
|
||||
"Determine when server connections should be established.",
|
||||
choices=("eager", "lazy")
|
||||
)
|
||||
@ -200,6 +200,7 @@ if __name__ == "__main__":
|
||||
|
||||
async def handle(reader, writer):
|
||||
layer_stack = [
|
||||
lambda ctx: layers.ServerTLSLayer(ctx),
|
||||
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular),
|
||||
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
|
||||
lambda ctx: layers.ClientTLSLayer(ctx),
|
||||
|
Loading…
Reference in New Issue
Block a user