diff --git a/mitmproxy/proxy2/server.py b/mitmproxy/proxy2/server.py index 55217d472..281938952 100644 --- a/mitmproxy/proxy2/server.py +++ b/mitmproxy/proxy2/server.py @@ -233,7 +233,6 @@ class ConnectionHandler(metaclass=abc.ABCMeta): async def on_timeout(self) -> None: self.log(f"Closing connection due to inactivity: {self.client}") - self.client.state = ConnectionState.CLOSED cancel_task(self.transports[self.client].handler, "timeout") async def hook_task(self, hook: commands.Hook) -> None: @@ -295,7 +294,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta): connection.state = ConnectionState.CLOSED if connection.state is ConnectionState.CLOSED: - cancel_task(self.transports[connection].handler, "closed by proxy") + cancel_task(self.transports[connection].handler, "closed by command") class StreamConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta): diff --git a/test/mitmproxy/proxy2/test_layer.py b/test/mitmproxy/proxy2/test_layer.py index 4162bb223..8dc46cef1 100644 --- a/test/mitmproxy/proxy2/test_layer.py +++ b/test/mitmproxy/proxy2/test_layer.py @@ -1,4 +1,6 @@ -from mitmproxy.proxy2 import layer, events, commands +import pytest + +from mitmproxy.proxy2 import commands, events, layer from test.mitmproxy.proxy2 import tutils @@ -50,6 +52,31 @@ class TestNextLayer: << commands.SendData(tctx.client, b"bar") ) + @pytest.mark.parametrize("layer_found", [True, False]) + def test_receive_close(self, tctx, layer_found): + """Test that we abort a client connection which has disconnected without any layer being found.""" + nl = layer.NextLayer(tctx) + playbook = tutils.Playbook(nl) + assert ( + playbook + >> events.DataReceived(tctx.client, b"foo") + << layer.NextLayerHook(nl) + >> events.ConnectionClosed(tctx.client) + ) + if layer_found: + nl.layer = tutils.RecordLayer(tctx) + assert ( + playbook + >> tutils.reply(to=-2) + ) + assert isinstance(nl.layer.event_log[-1], events.ConnectionClosed) + else: + assert ( + playbook + >> tutils.reply(to=-2) + << commands.CloseConnection(tctx.client) + ) + def test_func_references(self, tctx): nl = layer.NextLayer(tctx) playbook = tutils.Playbook(nl) diff --git a/test/mitmproxy/proxy2/tutils.py b/test/mitmproxy/proxy2/tutils.py index 71f7936e4..b3da9c03c 100644 --- a/test/mitmproxy/proxy2/tutils.py +++ b/test/mitmproxy/proxy2/tutils.py @@ -356,6 +356,19 @@ class EchoLayer(Layer): yield commands.CloseConnection(event.connection) +class RecordLayer(Layer): + """Layer that records all events but does nothing.""" + event_log: typing.List[events.Event] + + def __init__(self, context: context.Context) -> None: + super().__init__(context) + self.event_log = [] + + def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: + self.event_log.append(event) + yield from () + + def reply_next_layer( child_layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]], *args,