[sans-io] tls: handle multiple servers

This commit is contained in:
Maximilian Hils 2019-11-06 18:59:53 +01:00
parent ffb3782618
commit 549e41ee40
3 changed files with 117 additions and 90 deletions

View File

@ -601,6 +601,9 @@ class HTTPLayer(Layer):
self.connections[event.command.connection] = stream self.connections[event.command.connection] = stream
self.event_to_child(stream, event) self.event_to_child(stream, event)
elif isinstance(event, events.ConnectionEvent): elif isinstance(event, events.ConnectionEvent):
if event.connection == self.context.server and self.context.server not in self.connections:
pass
else:
handler = self.connections[event.connection] handler = self.connections[event.connection]
self.event_to_child(handler, event) self.event_to_child(handler, event)
else: else:
@ -624,6 +627,15 @@ class HTTPLayer(Layer):
else: else:
raise ValueError(f"Unexpected event: {event}") raise ValueError(f"Unexpected event: {event}")
def make_stream(self) -> HttpStream:
ctx = self.context.fork()
stream = HttpStream(ctx)
if self.debug:
stream.debug = self.debug + " "
self.event_to_child(stream, events.Start())
return stream
def get_connection(self, event: GetHttpConnection): def get_connection(self, event: GetHttpConnection):
# Do we already have a connection we can re-use? # Do we already have a connection we can re-use?
for connection, handler in self.connections.items(): for connection, handler in self.connections.items():
@ -646,6 +658,7 @@ class HTTPLayer(Layer):
return return
# Can we reuse context.server? # Can we reuse context.server?
can_reuse_context_connection = ( can_reuse_context_connection = (
self.context.server not in self.connections and
self.context.server.connected and self.context.server.connected and
self.context.server.tls == event.tls self.context.server.tls == event.tls
) )
@ -661,15 +674,6 @@ class HTTPLayer(Layer):
open_command.blocking = object() open_command.blocking = object()
yield open_command yield open_command
def make_stream(self) -> HttpStream:
ctx = self.context.fork()
stream = HttpStream(ctx)
if self.debug:
stream.debug = self.debug + " "
self.event_to_child(stream, events.Start())
return stream
def make_http_connection(self, connection: Server) -> None: def make_http_connection(self, connection: Server) -> None:
if connection.tls and not connection.tls_established: if connection.tls and not connection.tls_established:
new_command = EstablishServerTLS(connection) new_command = EstablishServerTLS(connection)

View File

@ -1,6 +1,6 @@
import os import os
import struct import struct
from typing import Any, Generator, Iterator, Optional from typing import Any, Dict, Generator, Iterator, Optional, Tuple
from OpenSSL import SSL from OpenSSL import SSL
@ -100,131 +100,154 @@ class EstablishServerTLSReply(events.CommandReply):
class _TLSLayer(layer.Layer): class _TLSLayer(layer.Layer):
conn: Optional[context.Connection] = None tls: Dict[context.Connection, SSL.Connection]
tls_conn: Optional[SSL.Connection] = None
child_layer: layer.Layer child_layer: layer.Layer
def __init__(self, context: context.Context):
super().__init__(context)
self.tls = {}
def __repr__(self): def __repr__(self):
if self.conn is None: if not self.tls:
state = "inactive" state = "inactive"
elif self.conn.tls_established:
state = f"passthrough {self.conn.sni}, {self.conn.alpn}"
else: else:
state = f"negotiating {self.conn.sni}, {self.conn.alpn}" conn_states = []
for conn in self.tls:
if conn.tls_established:
conn_states.append(f"passthrough {conn.sni} {conn.alpn}")
else:
conn_states.append(f"negotiating {conn.sni} {conn.alpn}")
state = ", ".join(conn_states)
return f"{type(self).__name__}({state})" return f"{type(self).__name__}({state})"
def tls_interact(self): def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator:
while True: while True:
try: try:
data = self.tls_conn.bio_read(65535) data = self.tls[conn].bio_read(65535)
except SSL.WantReadError: except SSL.WantReadError:
# Okay, nothing more waiting to be sent. # Okay, nothing more waiting to be sent.
return return
else: else:
yield commands.SendData(self.conn, data) yield commands.SendData(conn, data)
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]: def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
commands.Command, Any, Tuple[bool, Optional[str]]]:
# bio_write errors for b"", so we need to check first if we actually received something. # bio_write errors for b"", so we need to check first if we actually received something.
if data: if data:
self.tls_conn.bio_write(data) self.tls[conn].bio_write(data)
try: try:
self.tls_conn.do_handshake() self.tls[conn].do_handshake()
except SSL.WantReadError: except SSL.WantReadError:
yield from self.tls_interact() yield from self.tls_interact(conn)
return False return False, None
except SSL.ZeroReturnError: except SSL.ZeroReturnError as e:
raise # TODO: Figure out what to do when handshake fails. return False, repr(e)
else: else:
self.conn.tls_established = True conn.tls_established = True
self.conn.alpn = self.tls_conn.get_alpn_proto_negotiated() conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
yield commands.Log(f"TLS established: {self.conn}") yield commands.Log(f"TLS established: {conn}")
yield from self.receive(b"") yield from self.receive(conn, b"")
# TODO: Set all other connection attributes here # TODO: Set all other connection attributes here
# there might already be data in the OpenSSL BIO, so we need to trigger its processing. # there might already be data in the OpenSSL BIO, so we need to trigger its processing.
return True return True, None
def receive(self, data: bytes): def receive(self, conn: context.Connection, data: bytes):
if data: if data:
self.tls_conn.bio_write(data) self.tls[conn].bio_write(data)
yield from self.tls_interact() yield from self.tls_interact(conn)
plaintext = bytearray() plaintext = bytearray()
close = False
while True: while True:
try: try:
plaintext.extend(self.tls_conn.recv(65535)) plaintext.extend(self.tls[conn].recv(65535))
except (SSL.WantReadError, SSL.ZeroReturnError): except SSL.WantReadError:
break
except SSL.ZeroReturnError:
close = True
break break
if plaintext: if plaintext:
evt = events.DataReceived(self.conn, bytes(plaintext)) yield from self.event_to_child(
yield from self.event_to_child(evt) events.DataReceived(conn, bytes(plaintext))
)
if close:
conn.state &= ~context.ConnectionState.CAN_READ
yield commands.Log(f"TLS close_notify {conn=}")
yield from self.event_to_child(
events.ConnectionClosed(conn)
)
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator: def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
for command in self.child_layer.handle_event(event): for command in self.child_layer.handle_event(event):
if isinstance(command, commands.SendData) and command.connection == self.conn: if isinstance(command, commands.SendData) and command.connection in self.tls:
self.tls_conn.sendall(command.data) self.tls[command.connection].sendall(command.data)
yield from self.tls_interact() yield from self.tls_interact(command.connection)
else: else:
yield command yield command
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator: def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.DataReceived) and event.connection == self.conn: if isinstance(event, events.DataReceived) and event.connection in self.tls:
if not self.conn.tls_established: if not event.connection.tls_established:
yield from self.negotiate(event.data) yield from self.negotiate(event.connection, event.data)
else: else:
yield from self.receive(event.data) yield from self.receive(event.connection, event.data)
elif (
isinstance(event, events.ConnectionClosed) and
event.connection in self.tls and
self.tls[event.connection].get_shutdown() & SSL.RECEIVED_SHUTDOWN
):
pass # We have already dispatched a ConnectionClosed to the child layer.
else: else:
yield from self.event_to_child(event) yield from self.event_to_child(event)
class ServerTLSLayer(_TLSLayer): class ServerTLSLayer(_TLSLayer):
""" """
This layer manages TLS for a single server connection. This layer manages TLS for potentially multiple server connections.
""" """
command_to_reply_to: Optional[EstablishServerTLS] = None command_to_reply_to: Dict[context.Connection, EstablishServerTLS]
def __init__(self, context: context.Context): def __init__(self, context: context.Context):
super().__init__(context) super().__init__(context)
self.command_to_reply_to = {}
self.child_layer = layer.NextLayer(self.context) self.child_layer = layer.NextLayer(self.context)
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]: def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
done = yield from super().negotiate(data) commands.Command, Any, Tuple[bool, Optional[str]]]:
if done: done, err = yield from super().negotiate(conn, data)
assert self.command_to_reply_to if done or err:
yield from self.event_to_child(EstablishServerTLSReply(self.command_to_reply_to, None)) cmd = self.command_to_reply_to.pop(conn)
self.command_to_reply_to = None yield from self.event_to_child(EstablishServerTLSReply(cmd, err))
return done return done, err
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator: def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
for command in super().event_to_child(event): for command in super().event_to_child(event):
if isinstance(command, EstablishServerTLS): if isinstance(command, EstablishServerTLS):
assert isinstance(command.connection, context.Server) self.command_to_reply_to[command.connection] = command
assert not self.command_to_reply_to
self.command_to_reply_to = command
yield from self.start_server_tls(command.connection) yield from self.start_server_tls(command.connection)
else: else:
yield command yield command
def start_server_tls(self, server: context.Server): def start_server_tls(self, conn: context.Server):
assert conn not in self.tls
assert conn.connected
ssl_context = SSL.Context(SSL.SSLv23_METHOD) ssl_context = SSL.Context(SSL.SSLv23_METHOD)
if server.alpn_offers: if conn.alpn_offers:
ssl_context.set_alpn_protos(server.alpn_offers) ssl_context.set_alpn_protos(conn.alpn_offers)
self.tls[conn] = SSL.Connection(ssl_context)
assert not self.conn or not self.conn.connected if conn.sni:
assert server.connected if conn.sni is True:
self.conn = server
self.tls_conn = SSL.Connection(ssl_context)
if server.sni:
if server.sni is True:
if self.context.client.sni: if self.context.client.sni:
server.sni = self.context.client.sni conn.sni = self.context.client.sni
else: else:
server.sni = server.address[0] conn.sni = conn.address[0].encode()
self.tls_conn.set_tlsext_host_name(server.sni) self.tls[conn].set_tlsext_host_name(conn.sni)
self.tls_conn.set_connect_state() self.tls[conn].set_connect_state()
yield from self.negotiate(b"") yield from self.negotiate(conn, b"")
class ClientTLSLayer(_TLSLayer): class ClientTLSLayer(_TLSLayer):
@ -249,7 +272,6 @@ class ClientTLSLayer(_TLSLayer):
def __init__(self, context: context.Context): def __init__(self, context: context.Context):
assert isinstance(context.layers[-1], ServerTLSLayer) assert isinstance(context.layers[-1], ServerTLSLayer)
super().__init__(context) super().__init__(context)
self.conn = context.client
self.recv_buffer = bytearray() self.recv_buffer = bytearray()
self.child_layer = layer.NextLayer(self.context) self.child_layer = layer.NextLayer(self.context)
@ -269,7 +291,7 @@ class ClientTLSLayer(_TLSLayer):
try: try:
client_hello = parse_client_hello(self.recv_buffer) client_hello = parse_client_hello(self.recv_buffer)
except ValueError as e: except ValueError as e:
raise NotImplementedError() from e # TODO raise NotImplementedError from e # TODO
if client_hello: if client_hello:
yield commands.Log(f"Client Hello: {client_hello}") yield commands.Log(f"Client Hello: {client_hello}")
@ -292,7 +314,8 @@ class ClientTLSLayer(_TLSLayer):
if client_tls_requires_server_connection and not self.context.server.tls_established: if client_tls_requires_server_connection and not self.context.server.tls_established:
err = yield from self.start_server_tls() err = yield from self.start_server_tls()
if err: if err:
raise NotImplementedError yield commands.Log("Unable to establish TLS connection with server. "
"Trying to establish TLS with client anyway.")
yield from self.start_client_tls() yield from self.start_client_tls()
self._handle_event = super()._handle_event self._handle_event = super()._handle_event
@ -312,7 +335,7 @@ class ClientTLSLayer(_TLSLayer):
err = yield commands.OpenConnection(server) err = yield commands.OpenConnection(server)
if err: if err:
yield commands.Log( yield commands.Log(
"Cannot establish server connection, which is required to establish TLS with the client." f"Cannot establish server connection: {err}"
) )
return err return err
@ -324,11 +347,11 @@ class ClientTLSLayer(_TLSLayer):
err = yield EstablishServerTLS(server) err = yield EstablishServerTLS(server)
if err: if err:
yield commands.Log( yield commands.Log(
"Cannot establish TLS with server, which is required to establish TLS with the client." f"Cannot establish TLS with server: {err}"
) )
return err return err
def start_client_tls(self): def start_client_tls(self) -> commands.TCommandGenerator:
# FIXME: Do this properly. Also adjust error message in negotiate() # FIXME: Do this properly. Also adjust error message in negotiate()
client = self.context.client client = self.context.client
server = self.context.server server = self.context.server
@ -358,17 +381,15 @@ class ClientTLSLayer(_TLSLayer):
context.set_alpn_select_callback(alpn_select_callback) context.set_alpn_select_callback(alpn_select_callback)
self.tls_conn = SSL.Connection(context) self.tls[client] = SSL.Connection(context)
self.tls_conn.set_accept_state() self.tls[client].set_accept_state()
yield from self.negotiate(bytes(self.recv_buffer)) yield from self.negotiate(client, bytes(self.recv_buffer))
self.recv_buffer = None self.recv_buffer.clear()
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]: def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
try: done, err = yield from super().negotiate(conn, data)
done = yield from super().negotiate(data) if err:
return done
except SSL.ZeroReturnError:
yield commands.Log( yield commands.Log(
f"Client TLS Handshake failed. " f"Client TLS Handshake failed. "
f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).", f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).",
@ -376,3 +397,4 @@ class ClientTLSLayer(_TLSLayer):
# TODO: Also use other sources than SNI # TODO: Also use other sources than SNI
) )
yield commands.CloseConnection(self.context.client) yield commands.CloseConnection(self.context.client)
return done

View File

@ -191,7 +191,7 @@ if __name__ == "__main__":
opts = moptions.Options() opts = moptions.Options()
opts.add_option( opts.add_option(
"connection_strategy", str, "lazy", "connection_strategy", str, "eager",
"Determine when server connections should be established.", "Determine when server connections should be established.",
choices=("eager", "lazy") choices=("eager", "lazy")
) )
@ -200,6 +200,7 @@ if __name__ == "__main__":
async def handle(reader, writer): async def handle(reader, writer):
layer_stack = [ layer_stack = [
lambda ctx: layers.ServerTLSLayer(ctx),
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular), lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular),
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx), lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
lambda ctx: layers.ClientTLSLayer(ctx), lambda ctx: layers.ClientTLSLayer(ctx),