Merge pull request #4529 from mhils/eager-sans-io

Sans-IO Improvements: Connection Strategy
This commit is contained in:
Maximilian Hils 2021-03-30 10:16:57 +02:00 committed by GitHub
commit f94a9a3c9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 141 additions and 41 deletions

View File

@ -105,8 +105,6 @@ class NextLayer:
def s(*layers): def s(*layers):
return stack_match(context, layers) return stack_match(context, layers)
top_layer = context.layers[-1]
# 1. check for --ignore/--allow # 1. check for --ignore/--allow
ignore = self.ignore_connection(context.server.address, data_client) ignore = self.ignore_connection(context.server.address, data_client)
if ignore is True: if ignore is True:
@ -116,13 +114,17 @@ class NextLayer:
# 2. Check for TLS # 2. Check for TLS
if client_tls: if client_tls:
# client tls requires a server tls layer as parent layer # client tls usually requires a server tls layer as parent layer, except:
# reverse proxy mode manages this itself. # - reverse proxy mode manages this itself.
# a secure web proxy doesn't have a server part. # - a secure web proxy doesn't have a server part.
if isinstance(top_layer, layers.ServerTLSLayer) or s(modes.ReverseProxy) or s(modes.HttpProxy): if s(modes.ReverseProxy) or s(modes.HttpProxy):
return layers.ClientTLSLayer(context) return layers.ClientTLSLayer(context)
else: else:
return layers.ServerTLSLayer(context) # We already assign the next layer here os that ServerTLSLayer
# knows that it can safely wait for a ClientHello.
ret = layers.ServerTLSLayer(context)
ret.child_layer = layers.ClientTLSLayer(context)
return ret
# 3. Setup the HTTP layer for a regular HTTP proxy or an upstream proxy. # 3. Setup the HTTP layer for a regular HTTP proxy or an upstream proxy.
if any([ if any([

View File

@ -83,8 +83,11 @@ class Proxyserver:
def load(self, loader): def load(self, loader):
loader.add_option( loader.add_option(
"connection_strategy", str, "lazy", "connection_strategy", str, "eager",
"Determine when server connections should be established.", "Determine when server connections should be established. When set to lazy, mitmproxy "
"tries to defer establishing an upstream connection as long as possible. This makes it possible to "
"use server replay while being offline. When set to eager, mitmproxy can detect protocols with "
"server-side greetings, as well as accurately mirror TLS ALPN negotiation.",
choices=("eager", "lazy") choices=("eager", "lazy")
) )
loader.add_option( loader.add_option(

View File

@ -34,6 +34,10 @@ def alpn_select_callback(conn: SSL.Connection, options: List[bytes]) -> Any:
http2 = app_data["http2"] http2 = app_data["http2"]
if server_alpn and server_alpn in options: if server_alpn and server_alpn in options:
return server_alpn return server_alpn
if server_alpn == b"":
# We do have a server connection, but the remote server refused to negotiate a protocol:
# We need to mirror this on the client connection.
return SSL.NO_OVERLAPPING_PROTOCOLS
http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS
for alpn in options: # client sends in order of preference, so we are nice and respect that. for alpn in options: # client sends in order of preference, so we are nice and respect that.
if alpn in http_alpns: if alpn in http_alpns:

View File

@ -43,7 +43,13 @@ class Connection(serializable.Serializable, metaclass=ABCMeta):
sockname: Optional[Address] sockname: Optional[Address]
"""Our local `(ip, port)` tuple for this connection.""" """Our local `(ip, port)` tuple for this connection."""
error: Optional[str] = None error: Optional[str] = None
"""A string describing the connection error.""" """
A string describing a general error with connections to this address.
The purpose of this property is to signal that new connections to the particular endpoint should not be attempted,
for example because it uses an untrusted TLS certificate. Regular (unexpected) disconnects do not set the error
property. This property is only reused per client connection.
"""
tls: bool = False tls: bool = False
""" """

View File

@ -15,4 +15,7 @@ class ViewQuery(base.View):
return "Query", base.format_text("") return "Query", base.format_text("")
def render_priority(self, data: bytes, *, http_message: Optional[http.Message] = None, **metadata) -> float: def render_priority(self, data: bytes, *, http_message: Optional[http.Message] = None, **metadata) -> float:
return 0.3 * float(bool(getattr(http_message, "query", False))) return 0.3 * float(bool(
getattr(http_message, "query", False)
and not data
))

View File

@ -84,7 +84,7 @@ class CommandCompleted(Event):
command_reply_subclasses[command_cls] = cls command_reply_subclasses[command_cls] = cls
def __repr__(self): def __repr__(self):
return f"Reply({repr(self.command)})" return f"Reply({repr(self.command)},{repr(self.reply)})"
command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandCompleted]] = {} command_reply_subclasses: typing.Dict[commands.Command, typing.Type[CommandCompleted]] = {}

View File

@ -683,32 +683,40 @@ class HttpLayer(layer.Layer):
# Do we already have a connection we can re-use? # Do we already have a connection we can re-use?
if reuse: if reuse:
for connection in self.connections: for connection in self.connections:
# see "tricky multiplexing edge case" in make_http_connection for an explanation
conn_is_pending_or_h2 = (
connection.alpn == b"h2"
or connection in self.waiting_for_establishment
)
h2_to_h1 = self.context.client.alpn == b"h2" and not conn_is_pending_or_h2
connection_suitable = ( connection_suitable = (
event.connection_spec_matches(connection) event.connection_spec_matches(connection)
and not h2_to_h1
) )
if connection_suitable: if connection_suitable:
if connection in self.waiting_for_establishment: if connection in self.waiting_for_establishment:
self.waiting_for_establishment[connection].append(event) self.waiting_for_establishment[connection].append(event)
return return
elif connection.connected: elif connection.error:
stream = self.command_sources.pop(event) stream = self.command_sources.pop(event)
yield from self.event_to_child(stream, GetHttpConnectionCompleted(event, (connection, None))) yield from self.event_to_child(stream, GetHttpConnectionCompleted(event, (None, connection.error)))
return return
elif connection.connected:
# see "tricky multiplexing edge case" in make_http_connection for an explanation
h2_to_h1 = self.context.client.alpn == b"h2" and connection.alpn != b"h2"
if not h2_to_h1:
stream = self.command_sources.pop(event)
yield from self.event_to_child(stream, GetHttpConnectionCompleted(event, (connection, None)))
return
else: else:
pass # the connection is at least half-closed already, we want a new one. pass # the connection is at least half-closed already, we want a new one.
can_use_context_connection = ( context_connection_matches = (
self.context.server not in self.connections and self.context.server not in self.connections and
self.context.server.connected and
event.connection_spec_matches(self.context.server) event.connection_spec_matches(self.context.server)
) )
can_use_context_connection = (
context_connection_matches
and self.context.server.connected
)
if context_connection_matches and self.context.server.error:
stream = self.command_sources.pop(event)
yield from self.event_to_child(stream, GetHttpConnectionCompleted(event, (None, self.context.server.error)))
return
context = self.context.fork() context = self.context.fork()
stack = tunnel.LayerStack() stack = tunnel.LayerStack()

View File

@ -112,8 +112,15 @@ class BufferedH2Connection(h2.connection.H2Connection):
The window for a specific stream has updated. Send as much buffered data as possible. The window for a specific stream has updated. Send as much buffered data as possible.
""" """
# If the stream has been reset in the meantime, we just clear the buffer. # If the stream has been reset in the meantime, we just clear the buffer.
stream: h2.stream.H2Stream = self.streams[stream_id] try:
if stream.state_machine.state not in (h2.stream.StreamState.OPEN, h2.stream.StreamState.HALF_CLOSED_REMOTE): stream: h2.stream.H2Stream = self.streams[stream_id]
except KeyError:
stream_was_reset = True
else:
stream_was_reset = (
stream.state_machine.state not in (h2.stream.StreamState.OPEN, h2.stream.StreamState.HALF_CLOSED_REMOTE)
)
if stream_was_reset:
self.stream_buffers.pop(stream_id, None) self.stream_buffers.pop(stream_id, None)
return False return False

View File

@ -23,7 +23,7 @@ class DestinationKnown(layer.Layer, metaclass=ABCMeta):
child_layer: layer.Layer child_layer: layer.Layer
def finish_start(self) -> layer.CommandGenerator[Optional[str]]: def finish_start(self) -> layer.CommandGenerator[Optional[str]]:
if self.context.options.connection_strategy == "eager": if self.context.options.connection_strategy == "eager" and self.context.server.address:
err = yield commands.OpenConnection(self.context.server) err = yield commands.OpenConnection(self.context.server)
if err: if err:
self._handle_event = self.done # type: ignore self._handle_event = self.done # type: ignore

View File

@ -193,8 +193,9 @@ class _TLSLayer(tunnel.TunnelLayer):
elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii(): elif last_err == ('SSL routines', 'ssl3_get_record', 'wrong version number') and data[:4].isascii():
err = f"The remote server does not speak TLS." err = f"The remote server does not speak TLS."
else: # pragma: no cover else: # pragma: no cover
# TODO: Add test case one we find one. # TODO: Add test case once we find one.
err = f"OpenSSL {e!r}" err = f"OpenSSL {e!r}"
self.conn.error = err
return False, err return False, err
else: else:
# Here we set all attributes that are only known *after* the handshake. # Here we set all attributes that are only known *after* the handshake.
@ -254,7 +255,11 @@ class _TLSLayer(tunnel.TunnelLayer):
yield from super().receive_close() yield from super().receive_close()
def send_data(self, data: bytes) -> layer.CommandGenerator[None]: def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
self.tls.sendall(data) try:
self.tls.sendall(data)
except SSL.ZeroReturnError:
# The other peer may still be trying to send data over, which we discard here.
pass
yield from self.tls_interact() yield from self.tls_interact()
def send_close(self, half_close: bool) -> layer.CommandGenerator[None]: def send_close(self, half_close: bool) -> layer.CommandGenerator[None]:
@ -266,14 +271,39 @@ class ServerTLSLayer(_TLSLayer):
""" """
This layer establishes TLS for a single server connection. This layer establishes TLS for a single server connection.
""" """
command_to_reply_to: Optional[commands.OpenConnection] = None wait_for_clienthello: bool = False
def __init__(self, context: context.Context, conn: Optional[connection.Server] = None): def __init__(self, context: context.Context, conn: Optional[connection.Server] = None):
super().__init__(context, conn or context.server) super().__init__(context, conn or context.server)
def start_handshake(self) -> layer.CommandGenerator[None]: def start_handshake(self) -> layer.CommandGenerator[None]:
yield from self.start_tls() wait_for_clienthello = (
yield from self.receive_handshake_data(b"") # if command_to_reply_to is set, we've been instructed to open the connection from the child layer.
# in that case any potential ClientHello is already parsed (by the ClientTLS child layer).
not self.command_to_reply_to
# if command_to_reply_to is not set, the connection was already open when this layer received its Start
# event (eager connection strategy). We now want to establish TLS right away, _unless_ we already know
# that there's TLS on the client side as well (we check if our immediate child layer is set to be ClientTLS)
# In this case want to wait for ClientHello to be parsed, so that we can incorporate SNI/ALPN from there.
and isinstance(self.child_layer, ClientTLSLayer)
)
if wait_for_clienthello:
self.wait_for_clienthello = True
self.tunnel_state = tunnel.TunnelState.CLOSED
else:
yield from self.start_tls()
yield from self.receive_handshake_data(b"")
def event_to_child(self, event: events.Event) -> layer.CommandGenerator[None]:
if self.wait_for_clienthello:
for command in super().event_to_child(event):
if isinstance(command, commands.OpenConnection) and command.connection == self.conn:
self.wait_for_clienthello = False
# swallow OpenConnection here by not re-yielding it.
else:
yield command
else:
yield from super().event_to_child(event)
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]: def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
yield commands.Log(f"Server TLS handshake failed. {err}", level="warn") yield commands.Log(f"Server TLS handshake failed. {err}", level="warn")

View File

@ -75,7 +75,7 @@ class TunnelLayer(layer.Layer):
if self.tunnel_state is TunnelState.OPEN: if self.tunnel_state is TunnelState.OPEN:
yield from self.receive_close() yield from self.receive_close()
elif self.tunnel_state is TunnelState.ESTABLISHING: elif self.tunnel_state is TunnelState.ESTABLISHING:
err = "connection closed without notice" err = "connection closed"
yield from self.on_handshake_error(err) yield from self.on_handshake_error(err)
yield from self._handshake_finished(err) yield from self._handshake_finished(err)
self.tunnel_state = TunnelState.CLOSED self.tunnel_state = TunnelState.CLOSED

View File

@ -97,6 +97,10 @@ class TestNextLayer:
tctx.configure(nl, ignore_hosts=[]) tctx.configure(nl, ignore_hosts=[])
assert isinstance(nl._next_layer(ctx, client_hello_no_extensions, b""), layers.ServerTLSLayer) assert isinstance(nl._next_layer(ctx, client_hello_no_extensions, b""), layers.ServerTLSLayer)
assert isinstance(ctx.layers[-1], layers.ClientTLSLayer)
ctx.layers = []
assert isinstance(nl._next_layer(ctx, b"", b""), layers.modes.HttpProxy)
assert isinstance(nl._next_layer(ctx, client_hello_no_extensions, b""), layers.ClientTLSLayer) assert isinstance(nl._next_layer(ctx, client_hello_no_extensions, b""), layers.ClientTLSLayer)
ctx.layers = [] ctx.layers = []

View File

@ -30,6 +30,10 @@ def test_alpn_select_callback():
# Test no overlap # Test no overlap
assert tlsconfig.alpn_select_callback(conn, [b"qux", b"quux"]) == SSL.NO_OVERLAPPING_PROTOCOLS assert tlsconfig.alpn_select_callback(conn, [b"qux", b"quux"]) == SSL.NO_OVERLAPPING_PROTOCOLS
# Test that we don't select an ALPN if the server refused to select one.
conn.set_app_data(tlsconfig.AppData(server_alpn=b"", http2=True))
assert tlsconfig.alpn_select_callback(conn, [b"http/1.1"]) == SSL.NO_OVERLAPPING_PROTOCOLS
here = Path(__file__).parent here = Path(__file__).parent

View File

@ -546,6 +546,7 @@ def test_http_proxy_tcp(tctx, mode, close_first):
"""Test TCP over HTTP CONNECT.""" """Test TCP over HTTP CONNECT."""
server = Placeholder(Server) server = Placeholder(Server)
f = Placeholder(TCPFlow) f = Placeholder(TCPFlow)
tctx.options.connection_strategy = "lazy"
if mode == "upstream": if mode == "upstream":
tctx.options.mode = "upstream:http://proxy:8080" tctx.options.mode = "upstream:http://proxy:8080"
@ -813,6 +814,7 @@ def test_http_server_aborts(tctx, stream):
"response", "error"]) "response", "error"])
def test_kill_flow(tctx, when): def test_kill_flow(tctx, when):
"""Test that we properly kill flows if instructed to do so""" """Test that we properly kill flows if instructed to do so"""
tctx.options.connection_strategy = "lazy"
server = Placeholder(Server) server = Placeholder(Server)
connect_flow = Placeholder(HTTPFlow) connect_flow = Placeholder(HTTPFlow)
flow = Placeholder(HTTPFlow) flow = Placeholder(HTTPFlow)
@ -1000,3 +1002,18 @@ def test_dont_reuse_closed(tctx):
>> DataReceived(server2, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") >> DataReceived(server2, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n") << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
) )
def test_reuse_error(tctx):
"""Test that an errored connection is reused."""
tctx.server.address = ("example.com", 443)
tctx.server.error = "tls verify failed"
error_html = Placeholder(bytes)
assert (
Playbook(http.HttpLayer(tctx, HTTPMode.transparent), hooks=False)
>> DataReceived(tctx.client, b"GET / HTTP/1.1\r\n\r\n")
<< SendData(tctx.client, error_html)
<< CloseConnection(tctx.client)
)
assert b"502 Bad Gateway" in error_html()
assert b"tls verify failed" in error_html()

View File

@ -119,6 +119,7 @@ def test_reverse_proxy(tctx, keep_host_header):
""" """
server = Placeholder(Server) server = Placeholder(Server)
tctx.options.mode = "reverse:http://localhost:8000" tctx.options.mode = "reverse:http://localhost:8000"
tctx.options.connection_strategy = "lazy"
tctx.options.keep_host_header = keep_host_header tctx.options.keep_host_header = keep_host_header
assert ( assert (
Playbook(modes.ReverseProxy(tctx), hooks=False) Playbook(modes.ReverseProxy(tctx), hooks=False)
@ -321,6 +322,7 @@ def test_socks5_success(address: str, packed: bytes, tctx: Context):
def test_socks5_trickle(tctx: Context): def test_socks5_trickle(tctx: Context):
tctx.options.connection_strategy = "lazy"
playbook = Playbook(modes.Socks5Proxy(tctx)) playbook = Playbook(modes.Socks5Proxy(tctx))
for x in CLIENT_HELLO: for x in CLIENT_HELLO:
playbook >> DataReceived(tctx.client, bytes([x])) playbook >> DataReceived(tctx.client, bytes([x]))

View File

@ -391,11 +391,14 @@ class TestClientTLS:
<< commands.SendData(other_server, b"plaintext") << commands.SendData(other_server, b"plaintext")
) )
def test_server_required(self, tctx): @pytest.mark.parametrize("eager", ["eager", ""])
def test_server_required(self, tctx, eager):
""" """
Test the scenario where a server connection is required (for example, because of an unknown ALPN) Test the scenario where a server connection is required (for example, because of an unknown ALPN)
to establish TLS with the client. to establish TLS with the client.
""" """
if eager:
tctx.server.state = ConnectionState.OPEN
tssl_server = SSLTest(server_side=True, alpn=["quux"]) tssl_server = SSLTest(server_side=True, alpn=["quux"])
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, alpn=["quux"]) playbook, client_layer, tssl_client = make_client_tls_layer(tctx, alpn=["quux"])
@ -405,16 +408,23 @@ class TestClientTLS:
def require_server_conn(client_hello: tls.ClientHelloData) -> None: def require_server_conn(client_hello: tls.ClientHelloData) -> None:
client_hello.establish_server_tls_first = True client_hello.establish_server_tls_first = True
assert ( (
playbook playbook
>> events.DataReceived(tctx.client, tssl_client.bio_read()) >> events.DataReceived(tctx.client, tssl_client.bio_read())
<< tls.TlsClienthelloHook(tutils.Placeholder()) << tls.TlsClienthelloHook(tutils.Placeholder())
>> tutils.reply(side_effect=require_server_conn) >> tutils.reply(side_effect=require_server_conn)
)
if not eager:
(
playbook
<< commands.OpenConnection(tctx.server) << commands.OpenConnection(tctx.server)
>> tutils.reply(None) >> tutils.reply(None)
<< tls.TlsStartHook(tutils.Placeholder()) )
>> reply_tls_start(alpn=b"quux") assert (
<< commands.SendData(tctx.server, data) playbook
<< tls.TlsStartHook(tutils.Placeholder())
>> reply_tls_start(alpn=b"quux")
<< commands.SendData(tctx.server, data)
) )
# Establish TLS with the server... # Establish TLS with the server...
@ -509,6 +519,6 @@ class TestClientTLS:
<< commands.SendData(tctx.client, tutils.Placeholder()) << commands.SendData(tctx.client, tutils.Placeholder())
>> events.ConnectionClosed(tctx.client) >> events.ConnectionClosed(tctx.client)
<< commands.Log("Client TLS handshake failed. The client may not trust the proxy's certificate " << commands.Log("Client TLS handshake failed. The client may not trust the proxy's certificate "
"for wrong.host.mitmproxy.org (connection closed without notice)", "warn") "for wrong.host.mitmproxy.org (connection closed)", "warn")
<< commands.CloseConnection(tctx.client) << commands.CloseConnection(tctx.client)
) )

View File

@ -52,7 +52,7 @@ class TestLayer:
<< commands.Log(" >! DataReceived(client, b'foo')", "debug") << commands.Log(" >! DataReceived(client, b'foo')", "debug")
>> tutils.reply(None, to=-3) >> tutils.reply(None, to=-3)
<< commands.Log(" >> Reply(OpenConnection({'connection': Server(" << commands.Log(" >> Reply(OpenConnection({'connection': Server("
"{'id': '…rverid', 'address': None, 'state': <ConnectionState.OPEN: 3>})}))", "debug") "{'id': '…rverid', 'address': None, 'state': <ConnectionState.OPEN: 3>})}),None)", "debug")
<< commands.Log(" !> DataReceived(client, b'foo')", "debug") << commands.Log(" !> DataReceived(client, b'foo')", "debug")
<< commands.Log("baz", "info") << commands.Log("baz", "info")

View File

@ -245,7 +245,7 @@ def test_disconnect_during_handshake_command(tctx: Context, disconnect):
>> ConnectionClosed(tctx.client) >> ConnectionClosed(tctx.client)
>> ConnectionClosed(server) # proxyserver will cancel all other connections as well. >> ConnectionClosed(server) # proxyserver will cancel all other connections as well.
<< CloseConnection(server) << CloseConnection(server)
<< Log("Opened: err='connection closed without notice'. Server state: CLOSED") << Log("Opened: err='connection closed'. Server state: CLOSED")
<< Log("Got client close.") << Log("Got client close.")
<< CloseConnection(tctx.client) << CloseConnection(tctx.client)
) )
@ -254,7 +254,7 @@ def test_disconnect_during_handshake_command(tctx: Context, disconnect):
playbook playbook
>> ConnectionClosed(server) >> ConnectionClosed(server)
<< CloseConnection(server) << CloseConnection(server)
<< Log("Opened: err='connection closed without notice'. Server state: CLOSED") << Log("Opened: err='connection closed'. Server state: CLOSED")
) )