[sans-io] tls: handle multiple servers

This commit is contained in:
Maximilian Hils 2019-11-06 18:59:53 +01:00
parent ffb3782618
commit 549e41ee40
3 changed files with 117 additions and 90 deletions

View File

@ -601,8 +601,11 @@ class HTTPLayer(Layer):
self.connections[event.command.connection] = stream
self.event_to_child(stream, event)
elif isinstance(event, events.ConnectionEvent):
handler = self.connections[event.connection]
self.event_to_child(handler, event)
if event.connection == self.context.server and self.context.server not in self.connections:
pass
else:
handler = self.connections[event.connection]
self.event_to_child(handler, event)
else:
raise ValueError(f"Unexpected event: {event}")
@ -624,6 +627,15 @@ class HTTPLayer(Layer):
else:
raise ValueError(f"Unexpected event: {event}")
def make_stream(self) -> HttpStream:
ctx = self.context.fork()
stream = HttpStream(ctx)
if self.debug:
stream.debug = self.debug + " "
self.event_to_child(stream, events.Start())
return stream
def get_connection(self, event: GetHttpConnection):
# Do we already have a connection we can re-use?
for connection, handler in self.connections.items():
@ -646,6 +658,7 @@ class HTTPLayer(Layer):
return
# Can we reuse context.server?
can_reuse_context_connection = (
self.context.server not in self.connections and
self.context.server.connected and
self.context.server.tls == event.tls
)
@ -661,15 +674,6 @@ class HTTPLayer(Layer):
open_command.blocking = object()
yield open_command
def make_stream(self) -> HttpStream:
ctx = self.context.fork()
stream = HttpStream(ctx)
if self.debug:
stream.debug = self.debug + " "
self.event_to_child(stream, events.Start())
return stream
def make_http_connection(self, connection: Server) -> None:
if connection.tls and not connection.tls_established:
new_command = EstablishServerTLS(connection)

View File

@ -1,6 +1,6 @@
import os
import struct
from typing import Any, Generator, Iterator, Optional
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
from OpenSSL import SSL
@ -100,131 +100,154 @@ class EstablishServerTLSReply(events.CommandReply):
class _TLSLayer(layer.Layer):
conn: Optional[context.Connection] = None
tls_conn: Optional[SSL.Connection] = None
tls: Dict[context.Connection, SSL.Connection]
child_layer: layer.Layer
def __init__(self, context: context.Context):
super().__init__(context)
self.tls = {}
def __repr__(self):
if self.conn is None:
if not self.tls:
state = "inactive"
elif self.conn.tls_established:
state = f"passthrough {self.conn.sni}, {self.conn.alpn}"
else:
state = f"negotiating {self.conn.sni}, {self.conn.alpn}"
conn_states = []
for conn in self.tls:
if conn.tls_established:
conn_states.append(f"passthrough {conn.sni} {conn.alpn}")
else:
conn_states.append(f"negotiating {conn.sni} {conn.alpn}")
state = ", ".join(conn_states)
return f"{type(self).__name__}({state})"
def tls_interact(self):
def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator:
while True:
try:
data = self.tls_conn.bio_read(65535)
data = self.tls[conn].bio_read(65535)
except SSL.WantReadError:
# Okay, nothing more waiting to be sent.
return
else:
yield commands.SendData(self.conn, data)
yield commands.SendData(conn, data)
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]:
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
commands.Command, Any, Tuple[bool, Optional[str]]]:
# bio_write errors for b"", so we need to check first if we actually received something.
if data:
self.tls_conn.bio_write(data)
self.tls[conn].bio_write(data)
try:
self.tls_conn.do_handshake()
self.tls[conn].do_handshake()
except SSL.WantReadError:
yield from self.tls_interact()
return False
except SSL.ZeroReturnError:
raise # TODO: Figure out what to do when handshake fails.
yield from self.tls_interact(conn)
return False, None
except SSL.ZeroReturnError as e:
return False, repr(e)
else:
self.conn.tls_established = True
self.conn.alpn = self.tls_conn.get_alpn_proto_negotiated()
yield commands.Log(f"TLS established: {self.conn}")
yield from self.receive(b"")
conn.tls_established = True
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
yield commands.Log(f"TLS established: {conn}")
yield from self.receive(conn, b"")
# TODO: Set all other connection attributes here
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
return True
return True, None
def receive(self, data: bytes):
def receive(self, conn: context.Connection, data: bytes):
if data:
self.tls_conn.bio_write(data)
yield from self.tls_interact()
self.tls[conn].bio_write(data)
yield from self.tls_interact(conn)
plaintext = bytearray()
close = False
while True:
try:
plaintext.extend(self.tls_conn.recv(65535))
except (SSL.WantReadError, SSL.ZeroReturnError):
plaintext.extend(self.tls[conn].recv(65535))
except SSL.WantReadError:
break
except SSL.ZeroReturnError:
close = True
break
if plaintext:
evt = events.DataReceived(self.conn, bytes(plaintext))
yield from self.event_to_child(evt)
yield from self.event_to_child(
events.DataReceived(conn, bytes(plaintext))
)
if close:
conn.state &= ~context.ConnectionState.CAN_READ
yield commands.Log(f"TLS close_notify {conn=}")
yield from self.event_to_child(
events.ConnectionClosed(conn)
)
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
for command in self.child_layer.handle_event(event):
if isinstance(command, commands.SendData) and command.connection == self.conn:
self.tls_conn.sendall(command.data)
yield from self.tls_interact()
if isinstance(command, commands.SendData) and command.connection in self.tls:
self.tls[command.connection].sendall(command.data)
yield from self.tls_interact(command.connection)
else:
yield command
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
if isinstance(event, events.DataReceived) and event.connection == self.conn:
if not self.conn.tls_established:
yield from self.negotiate(event.data)
if isinstance(event, events.DataReceived) and event.connection in self.tls:
if not event.connection.tls_established:
yield from self.negotiate(event.connection, event.data)
else:
yield from self.receive(event.data)
yield from self.receive(event.connection, event.data)
elif (
isinstance(event, events.ConnectionClosed) and
event.connection in self.tls and
self.tls[event.connection].get_shutdown() & SSL.RECEIVED_SHUTDOWN
):
pass # We have already dispatched a ConnectionClosed to the child layer.
else:
yield from self.event_to_child(event)
class ServerTLSLayer(_TLSLayer):
"""
This layer manages TLS for a single server connection.
This layer manages TLS for potentially multiple server connections.
"""
command_to_reply_to: Optional[EstablishServerTLS] = None
command_to_reply_to: Dict[context.Connection, EstablishServerTLS]
def __init__(self, context: context.Context):
super().__init__(context)
self.command_to_reply_to = {}
self.child_layer = layer.NextLayer(self.context)
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]:
done = yield from super().negotiate(data)
if done:
assert self.command_to_reply_to
yield from self.event_to_child(EstablishServerTLSReply(self.command_to_reply_to, None))
self.command_to_reply_to = None
return done
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
commands.Command, Any, Tuple[bool, Optional[str]]]:
done, err = yield from super().negotiate(conn, data)
if done or err:
cmd = self.command_to_reply_to.pop(conn)
yield from self.event_to_child(EstablishServerTLSReply(cmd, err))
return done, err
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
for command in super().event_to_child(event):
if isinstance(command, EstablishServerTLS):
assert isinstance(command.connection, context.Server)
assert not self.command_to_reply_to
self.command_to_reply_to = command
self.command_to_reply_to[command.connection] = command
yield from self.start_server_tls(command.connection)
else:
yield command
def start_server_tls(self, server: context.Server):
def start_server_tls(self, conn: context.Server):
assert conn not in self.tls
assert conn.connected
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
if server.alpn_offers:
ssl_context.set_alpn_protos(server.alpn_offers)
if conn.alpn_offers:
ssl_context.set_alpn_protos(conn.alpn_offers)
self.tls[conn] = SSL.Connection(ssl_context)
assert not self.conn or not self.conn.connected
assert server.connected
self.conn = server
self.tls_conn = SSL.Connection(ssl_context)
if server.sni:
if server.sni is True:
if conn.sni:
if conn.sni is True:
if self.context.client.sni:
server.sni = self.context.client.sni
conn.sni = self.context.client.sni
else:
server.sni = server.address[0]
self.tls_conn.set_tlsext_host_name(server.sni)
self.tls_conn.set_connect_state()
conn.sni = conn.address[0].encode()
self.tls[conn].set_tlsext_host_name(conn.sni)
self.tls[conn].set_connect_state()
yield from self.negotiate(b"")
yield from self.negotiate(conn, b"")
class ClientTLSLayer(_TLSLayer):
@ -249,7 +272,6 @@ class ClientTLSLayer(_TLSLayer):
def __init__(self, context: context.Context):
assert isinstance(context.layers[-1], ServerTLSLayer)
super().__init__(context)
self.conn = context.client
self.recv_buffer = bytearray()
self.child_layer = layer.NextLayer(self.context)
@ -269,7 +291,7 @@ class ClientTLSLayer(_TLSLayer):
try:
client_hello = parse_client_hello(self.recv_buffer)
except ValueError as e:
raise NotImplementedError() from e # TODO
raise NotImplementedError from e # TODO
if client_hello:
yield commands.Log(f"Client Hello: {client_hello}")
@ -292,7 +314,8 @@ class ClientTLSLayer(_TLSLayer):
if client_tls_requires_server_connection and not self.context.server.tls_established:
err = yield from self.start_server_tls()
if err:
raise NotImplementedError
yield commands.Log("Unable to establish TLS connection with server. "
"Trying to establish TLS with client anyway.")
yield from self.start_client_tls()
self._handle_event = super()._handle_event
@ -312,7 +335,7 @@ class ClientTLSLayer(_TLSLayer):
err = yield commands.OpenConnection(server)
if err:
yield commands.Log(
"Cannot establish server connection, which is required to establish TLS with the client."
f"Cannot establish server connection: {err}"
)
return err
@ -324,11 +347,11 @@ class ClientTLSLayer(_TLSLayer):
err = yield EstablishServerTLS(server)
if err:
yield commands.Log(
"Cannot establish TLS with server, which is required to establish TLS with the client."
f"Cannot establish TLS with server: {err}"
)
return err
def start_client_tls(self):
def start_client_tls(self) -> commands.TCommandGenerator:
# FIXME: Do this properly. Also adjust error message in negotiate()
client = self.context.client
server = self.context.server
@ -358,17 +381,15 @@ class ClientTLSLayer(_TLSLayer):
context.set_alpn_select_callback(alpn_select_callback)
self.tls_conn = SSL.Connection(context)
self.tls_conn.set_accept_state()
self.tls[client] = SSL.Connection(context)
self.tls[client].set_accept_state()
yield from self.negotiate(bytes(self.recv_buffer))
self.recv_buffer = None
yield from self.negotiate(client, bytes(self.recv_buffer))
self.recv_buffer.clear()
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]:
try:
done = yield from super().negotiate(data)
return done
except SSL.ZeroReturnError:
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
done, err = yield from super().negotiate(conn, data)
if err:
yield commands.Log(
f"Client TLS Handshake failed. "
f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).",
@ -376,3 +397,4 @@ class ClientTLSLayer(_TLSLayer):
# TODO: Also use other sources than SNI
)
yield commands.CloseConnection(self.context.client)
return done

View File

@ -191,7 +191,7 @@ if __name__ == "__main__":
opts = moptions.Options()
opts.add_option(
"connection_strategy", str, "lazy",
"connection_strategy", str, "eager",
"Determine when server connections should be established.",
choices=("eager", "lazy")
)
@ -200,6 +200,7 @@ if __name__ == "__main__":
async def handle(reader, writer):
layer_stack = [
lambda ctx: layers.ServerTLSLayer(ctx),
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular),
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
lambda ctx: layers.ClientTLSLayer(ctx),