mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] tls: handle multiple servers
This commit is contained in:
parent
ffb3782618
commit
549e41ee40
@ -601,6 +601,9 @@ class HTTPLayer(Layer):
|
|||||||
self.connections[event.command.connection] = stream
|
self.connections[event.command.connection] = stream
|
||||||
self.event_to_child(stream, event)
|
self.event_to_child(stream, event)
|
||||||
elif isinstance(event, events.ConnectionEvent):
|
elif isinstance(event, events.ConnectionEvent):
|
||||||
|
if event.connection == self.context.server and self.context.server not in self.connections:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
handler = self.connections[event.connection]
|
handler = self.connections[event.connection]
|
||||||
self.event_to_child(handler, event)
|
self.event_to_child(handler, event)
|
||||||
else:
|
else:
|
||||||
@ -624,6 +627,15 @@ class HTTPLayer(Layer):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
raise ValueError(f"Unexpected event: {event}")
|
||||||
|
|
||||||
|
def make_stream(self) -> HttpStream:
|
||||||
|
ctx = self.context.fork()
|
||||||
|
|
||||||
|
stream = HttpStream(ctx)
|
||||||
|
if self.debug:
|
||||||
|
stream.debug = self.debug + " "
|
||||||
|
self.event_to_child(stream, events.Start())
|
||||||
|
return stream
|
||||||
|
|
||||||
def get_connection(self, event: GetHttpConnection):
|
def get_connection(self, event: GetHttpConnection):
|
||||||
# Do we already have a connection we can re-use?
|
# Do we already have a connection we can re-use?
|
||||||
for connection, handler in self.connections.items():
|
for connection, handler in self.connections.items():
|
||||||
@ -646,6 +658,7 @@ class HTTPLayer(Layer):
|
|||||||
return
|
return
|
||||||
# Can we reuse context.server?
|
# Can we reuse context.server?
|
||||||
can_reuse_context_connection = (
|
can_reuse_context_connection = (
|
||||||
|
self.context.server not in self.connections and
|
||||||
self.context.server.connected and
|
self.context.server.connected and
|
||||||
self.context.server.tls == event.tls
|
self.context.server.tls == event.tls
|
||||||
)
|
)
|
||||||
@ -661,15 +674,6 @@ class HTTPLayer(Layer):
|
|||||||
open_command.blocking = object()
|
open_command.blocking = object()
|
||||||
yield open_command
|
yield open_command
|
||||||
|
|
||||||
def make_stream(self) -> HttpStream:
|
|
||||||
ctx = self.context.fork()
|
|
||||||
|
|
||||||
stream = HttpStream(ctx)
|
|
||||||
if self.debug:
|
|
||||||
stream.debug = self.debug + " "
|
|
||||||
self.event_to_child(stream, events.Start())
|
|
||||||
return stream
|
|
||||||
|
|
||||||
def make_http_connection(self, connection: Server) -> None:
|
def make_http_connection(self, connection: Server) -> None:
|
||||||
if connection.tls and not connection.tls_established:
|
if connection.tls and not connection.tls_established:
|
||||||
new_command = EstablishServerTLS(connection)
|
new_command = EstablishServerTLS(connection)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
from typing import Any, Generator, Iterator, Optional
|
from typing import Any, Dict, Generator, Iterator, Optional, Tuple
|
||||||
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
@ -100,131 +100,154 @@ class EstablishServerTLSReply(events.CommandReply):
|
|||||||
|
|
||||||
|
|
||||||
class _TLSLayer(layer.Layer):
|
class _TLSLayer(layer.Layer):
|
||||||
conn: Optional[context.Connection] = None
|
tls: Dict[context.Connection, SSL.Connection]
|
||||||
tls_conn: Optional[SSL.Connection] = None
|
|
||||||
child_layer: layer.Layer
|
child_layer: layer.Layer
|
||||||
|
|
||||||
|
def __init__(self, context: context.Context):
|
||||||
|
super().__init__(context)
|
||||||
|
self.tls = {}
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if self.conn is None:
|
if not self.tls:
|
||||||
state = "inactive"
|
state = "inactive"
|
||||||
elif self.conn.tls_established:
|
|
||||||
state = f"passthrough {self.conn.sni}, {self.conn.alpn}"
|
|
||||||
else:
|
else:
|
||||||
state = f"negotiating {self.conn.sni}, {self.conn.alpn}"
|
conn_states = []
|
||||||
|
for conn in self.tls:
|
||||||
|
if conn.tls_established:
|
||||||
|
conn_states.append(f"passthrough {conn.sni} {conn.alpn}")
|
||||||
|
else:
|
||||||
|
conn_states.append(f"negotiating {conn.sni} {conn.alpn}")
|
||||||
|
state = ", ".join(conn_states)
|
||||||
return f"{type(self).__name__}({state})"
|
return f"{type(self).__name__}({state})"
|
||||||
|
|
||||||
def tls_interact(self):
|
def tls_interact(self, conn: context.Connection) -> commands.TCommandGenerator:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
data = self.tls_conn.bio_read(65535)
|
data = self.tls[conn].bio_read(65535)
|
||||||
except SSL.WantReadError:
|
except SSL.WantReadError:
|
||||||
# Okay, nothing more waiting to be sent.
|
# Okay, nothing more waiting to be sent.
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
yield commands.SendData(self.conn, data)
|
yield commands.SendData(conn, data)
|
||||||
|
|
||||||
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]:
|
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
|
||||||
|
commands.Command, Any, Tuple[bool, Optional[str]]]:
|
||||||
# bio_write errors for b"", so we need to check first if we actually received something.
|
# bio_write errors for b"", so we need to check first if we actually received something.
|
||||||
if data:
|
if data:
|
||||||
self.tls_conn.bio_write(data)
|
self.tls[conn].bio_write(data)
|
||||||
try:
|
try:
|
||||||
self.tls_conn.do_handshake()
|
self.tls[conn].do_handshake()
|
||||||
except SSL.WantReadError:
|
except SSL.WantReadError:
|
||||||
yield from self.tls_interact()
|
yield from self.tls_interact(conn)
|
||||||
return False
|
return False, None
|
||||||
except SSL.ZeroReturnError:
|
except SSL.ZeroReturnError as e:
|
||||||
raise # TODO: Figure out what to do when handshake fails.
|
return False, repr(e)
|
||||||
else:
|
else:
|
||||||
self.conn.tls_established = True
|
conn.tls_established = True
|
||||||
self.conn.alpn = self.tls_conn.get_alpn_proto_negotiated()
|
conn.alpn = self.tls[conn].get_alpn_proto_negotiated()
|
||||||
yield commands.Log(f"TLS established: {self.conn}")
|
yield commands.Log(f"TLS established: {conn}")
|
||||||
yield from self.receive(b"")
|
yield from self.receive(conn, b"")
|
||||||
# TODO: Set all other connection attributes here
|
# TODO: Set all other connection attributes here
|
||||||
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
|
# there might already be data in the OpenSSL BIO, so we need to trigger its processing.
|
||||||
return True
|
return True, None
|
||||||
|
|
||||||
def receive(self, data: bytes):
|
def receive(self, conn: context.Connection, data: bytes):
|
||||||
if data:
|
if data:
|
||||||
self.tls_conn.bio_write(data)
|
self.tls[conn].bio_write(data)
|
||||||
yield from self.tls_interact()
|
yield from self.tls_interact(conn)
|
||||||
|
|
||||||
plaintext = bytearray()
|
plaintext = bytearray()
|
||||||
|
close = False
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
plaintext.extend(self.tls_conn.recv(65535))
|
plaintext.extend(self.tls[conn].recv(65535))
|
||||||
except (SSL.WantReadError, SSL.ZeroReturnError):
|
except SSL.WantReadError:
|
||||||
|
break
|
||||||
|
except SSL.ZeroReturnError:
|
||||||
|
close = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if plaintext:
|
if plaintext:
|
||||||
evt = events.DataReceived(self.conn, bytes(plaintext))
|
yield from self.event_to_child(
|
||||||
yield from self.event_to_child(evt)
|
events.DataReceived(conn, bytes(plaintext))
|
||||||
|
)
|
||||||
|
if close:
|
||||||
|
conn.state &= ~context.ConnectionState.CAN_READ
|
||||||
|
yield commands.Log(f"TLS close_notify {conn=}")
|
||||||
|
yield from self.event_to_child(
|
||||||
|
events.ConnectionClosed(conn)
|
||||||
|
)
|
||||||
|
|
||||||
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
for command in self.child_layer.handle_event(event):
|
for command in self.child_layer.handle_event(event):
|
||||||
if isinstance(command, commands.SendData) and command.connection == self.conn:
|
if isinstance(command, commands.SendData) and command.connection in self.tls:
|
||||||
self.tls_conn.sendall(command.data)
|
self.tls[command.connection].sendall(command.data)
|
||||||
yield from self.tls_interact()
|
yield from self.tls_interact(command.connection)
|
||||||
else:
|
else:
|
||||||
yield command
|
yield command
|
||||||
|
|
||||||
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
def _handle_event(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
if isinstance(event, events.DataReceived) and event.connection == self.conn:
|
if isinstance(event, events.DataReceived) and event.connection in self.tls:
|
||||||
if not self.conn.tls_established:
|
if not event.connection.tls_established:
|
||||||
yield from self.negotiate(event.data)
|
yield from self.negotiate(event.connection, event.data)
|
||||||
else:
|
else:
|
||||||
yield from self.receive(event.data)
|
yield from self.receive(event.connection, event.data)
|
||||||
|
elif (
|
||||||
|
isinstance(event, events.ConnectionClosed) and
|
||||||
|
event.connection in self.tls and
|
||||||
|
self.tls[event.connection].get_shutdown() & SSL.RECEIVED_SHUTDOWN
|
||||||
|
):
|
||||||
|
pass # We have already dispatched a ConnectionClosed to the child layer.
|
||||||
else:
|
else:
|
||||||
yield from self.event_to_child(event)
|
yield from self.event_to_child(event)
|
||||||
|
|
||||||
|
|
||||||
class ServerTLSLayer(_TLSLayer):
|
class ServerTLSLayer(_TLSLayer):
|
||||||
"""
|
"""
|
||||||
This layer manages TLS for a single server connection.
|
This layer manages TLS for potentially multiple server connections.
|
||||||
"""
|
"""
|
||||||
command_to_reply_to: Optional[EstablishServerTLS] = None
|
command_to_reply_to: Dict[context.Connection, EstablishServerTLS]
|
||||||
|
|
||||||
def __init__(self, context: context.Context):
|
def __init__(self, context: context.Context):
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
|
self.command_to_reply_to = {}
|
||||||
self.child_layer = layer.NextLayer(self.context)
|
self.child_layer = layer.NextLayer(self.context)
|
||||||
|
|
||||||
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]:
|
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[
|
||||||
done = yield from super().negotiate(data)
|
commands.Command, Any, Tuple[bool, Optional[str]]]:
|
||||||
if done:
|
done, err = yield from super().negotiate(conn, data)
|
||||||
assert self.command_to_reply_to
|
if done or err:
|
||||||
yield from self.event_to_child(EstablishServerTLSReply(self.command_to_reply_to, None))
|
cmd = self.command_to_reply_to.pop(conn)
|
||||||
self.command_to_reply_to = None
|
yield from self.event_to_child(EstablishServerTLSReply(cmd, err))
|
||||||
return done
|
return done, err
|
||||||
|
|
||||||
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
def event_to_child(self, event: events.Event) -> commands.TCommandGenerator:
|
||||||
for command in super().event_to_child(event):
|
for command in super().event_to_child(event):
|
||||||
if isinstance(command, EstablishServerTLS):
|
if isinstance(command, EstablishServerTLS):
|
||||||
assert isinstance(command.connection, context.Server)
|
self.command_to_reply_to[command.connection] = command
|
||||||
assert not self.command_to_reply_to
|
|
||||||
self.command_to_reply_to = command
|
|
||||||
yield from self.start_server_tls(command.connection)
|
yield from self.start_server_tls(command.connection)
|
||||||
else:
|
else:
|
||||||
yield command
|
yield command
|
||||||
|
|
||||||
def start_server_tls(self, server: context.Server):
|
def start_server_tls(self, conn: context.Server):
|
||||||
|
assert conn not in self.tls
|
||||||
|
assert conn.connected
|
||||||
|
|
||||||
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
ssl_context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
if server.alpn_offers:
|
if conn.alpn_offers:
|
||||||
ssl_context.set_alpn_protos(server.alpn_offers)
|
ssl_context.set_alpn_protos(conn.alpn_offers)
|
||||||
|
self.tls[conn] = SSL.Connection(ssl_context)
|
||||||
|
|
||||||
assert not self.conn or not self.conn.connected
|
if conn.sni:
|
||||||
assert server.connected
|
if conn.sni is True:
|
||||||
self.conn = server
|
|
||||||
self.tls_conn = SSL.Connection(ssl_context)
|
|
||||||
|
|
||||||
if server.sni:
|
|
||||||
if server.sni is True:
|
|
||||||
if self.context.client.sni:
|
if self.context.client.sni:
|
||||||
server.sni = self.context.client.sni
|
conn.sni = self.context.client.sni
|
||||||
else:
|
else:
|
||||||
server.sni = server.address[0]
|
conn.sni = conn.address[0].encode()
|
||||||
self.tls_conn.set_tlsext_host_name(server.sni)
|
self.tls[conn].set_tlsext_host_name(conn.sni)
|
||||||
self.tls_conn.set_connect_state()
|
self.tls[conn].set_connect_state()
|
||||||
|
|
||||||
yield from self.negotiate(b"")
|
yield from self.negotiate(conn, b"")
|
||||||
|
|
||||||
|
|
||||||
class ClientTLSLayer(_TLSLayer):
|
class ClientTLSLayer(_TLSLayer):
|
||||||
@ -249,7 +272,6 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
def __init__(self, context: context.Context):
|
def __init__(self, context: context.Context):
|
||||||
assert isinstance(context.layers[-1], ServerTLSLayer)
|
assert isinstance(context.layers[-1], ServerTLSLayer)
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.conn = context.client
|
|
||||||
self.recv_buffer = bytearray()
|
self.recv_buffer = bytearray()
|
||||||
self.child_layer = layer.NextLayer(self.context)
|
self.child_layer = layer.NextLayer(self.context)
|
||||||
|
|
||||||
@ -269,7 +291,7 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
try:
|
try:
|
||||||
client_hello = parse_client_hello(self.recv_buffer)
|
client_hello = parse_client_hello(self.recv_buffer)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise NotImplementedError() from e # TODO
|
raise NotImplementedError from e # TODO
|
||||||
|
|
||||||
if client_hello:
|
if client_hello:
|
||||||
yield commands.Log(f"Client Hello: {client_hello}")
|
yield commands.Log(f"Client Hello: {client_hello}")
|
||||||
@ -292,7 +314,8 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
if client_tls_requires_server_connection and not self.context.server.tls_established:
|
if client_tls_requires_server_connection and not self.context.server.tls_established:
|
||||||
err = yield from self.start_server_tls()
|
err = yield from self.start_server_tls()
|
||||||
if err:
|
if err:
|
||||||
raise NotImplementedError
|
yield commands.Log("Unable to establish TLS connection with server. "
|
||||||
|
"Trying to establish TLS with client anyway.")
|
||||||
|
|
||||||
yield from self.start_client_tls()
|
yield from self.start_client_tls()
|
||||||
self._handle_event = super()._handle_event
|
self._handle_event = super()._handle_event
|
||||||
@ -312,7 +335,7 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
err = yield commands.OpenConnection(server)
|
err = yield commands.OpenConnection(server)
|
||||||
if err:
|
if err:
|
||||||
yield commands.Log(
|
yield commands.Log(
|
||||||
"Cannot establish server connection, which is required to establish TLS with the client."
|
f"Cannot establish server connection: {err}"
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
|
|
||||||
@ -324,11 +347,11 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
err = yield EstablishServerTLS(server)
|
err = yield EstablishServerTLS(server)
|
||||||
if err:
|
if err:
|
||||||
yield commands.Log(
|
yield commands.Log(
|
||||||
"Cannot establish TLS with server, which is required to establish TLS with the client."
|
f"Cannot establish TLS with server: {err}"
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
|
|
||||||
def start_client_tls(self):
|
def start_client_tls(self) -> commands.TCommandGenerator:
|
||||||
# FIXME: Do this properly. Also adjust error message in negotiate()
|
# FIXME: Do this properly. Also adjust error message in negotiate()
|
||||||
client = self.context.client
|
client = self.context.client
|
||||||
server = self.context.server
|
server = self.context.server
|
||||||
@ -358,17 +381,15 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
|
|
||||||
context.set_alpn_select_callback(alpn_select_callback)
|
context.set_alpn_select_callback(alpn_select_callback)
|
||||||
|
|
||||||
self.tls_conn = SSL.Connection(context)
|
self.tls[client] = SSL.Connection(context)
|
||||||
self.tls_conn.set_accept_state()
|
self.tls[client].set_accept_state()
|
||||||
|
|
||||||
yield from self.negotiate(bytes(self.recv_buffer))
|
yield from self.negotiate(client, bytes(self.recv_buffer))
|
||||||
self.recv_buffer = None
|
self.recv_buffer.clear()
|
||||||
|
|
||||||
def negotiate(self, data: bytes) -> Generator[commands.Command, Any, bool]:
|
def negotiate(self, conn: context.Connection, data: bytes) -> Generator[commands.Command, Any, bool]:
|
||||||
try:
|
done, err = yield from super().negotiate(conn, data)
|
||||||
done = yield from super().negotiate(data)
|
if err:
|
||||||
return done
|
|
||||||
except SSL.ZeroReturnError:
|
|
||||||
yield commands.Log(
|
yield commands.Log(
|
||||||
f"Client TLS Handshake failed. "
|
f"Client TLS Handshake failed. "
|
||||||
f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).",
|
f"The client may not trust the proxy's certificate (SNI: {self.context.client.sni}).",
|
||||||
@ -376,3 +397,4 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
# TODO: Also use other sources than SNI
|
# TODO: Also use other sources than SNI
|
||||||
)
|
)
|
||||||
yield commands.CloseConnection(self.context.client)
|
yield commands.CloseConnection(self.context.client)
|
||||||
|
return done
|
||||||
|
@ -191,7 +191,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
opts = moptions.Options()
|
opts = moptions.Options()
|
||||||
opts.add_option(
|
opts.add_option(
|
||||||
"connection_strategy", str, "lazy",
|
"connection_strategy", str, "eager",
|
||||||
"Determine when server connections should be established.",
|
"Determine when server connections should be established.",
|
||||||
choices=("eager", "lazy")
|
choices=("eager", "lazy")
|
||||||
)
|
)
|
||||||
@ -200,6 +200,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
async def handle(reader, writer):
|
async def handle(reader, writer):
|
||||||
layer_stack = [
|
layer_stack = [
|
||||||
|
lambda ctx: layers.ServerTLSLayer(ctx),
|
||||||
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular),
|
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular),
|
||||||
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
|
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
|
||||||
lambda ctx: layers.ClientTLSLayer(ctx),
|
lambda ctx: layers.ClientTLSLayer(ctx),
|
||||||
|
Loading…
Reference in New Issue
Block a user