[sans-io] test NextLayer behavior on connection close

This commit is contained in:
Maximilian Hils 2020-11-21 14:01:45 +01:00
parent 65870b729f
commit b5f59a1297
3 changed files with 42 additions and 3 deletions

View File

@ -233,7 +233,6 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
self.client.state = ConnectionState.CLOSED
cancel_task(self.transports[self.client].handler, "timeout")
async def hook_task(self, hook: commands.Hook) -> None:
@ -295,7 +294,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
connection.state = ConnectionState.CLOSED
if connection.state is ConnectionState.CLOSED:
cancel_task(self.transports[connection].handler, "closed by proxy")
cancel_task(self.transports[connection].handler, "closed by command")
class StreamConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):

View File

@ -1,4 +1,6 @@
from mitmproxy.proxy2 import layer, events, commands
import pytest
from mitmproxy.proxy2 import commands, events, layer
from test.mitmproxy.proxy2 import tutils
@ -50,6 +52,31 @@ class TestNextLayer:
<< commands.SendData(tctx.client, b"bar")
)
@pytest.mark.parametrize("layer_found", [True, False])
def test_receive_close(self, tctx, layer_found):
"""Test that we abort a client connection which has disconnected without any layer being found."""
nl = layer.NextLayer(tctx)
playbook = tutils.Playbook(nl)
assert (
playbook
>> events.DataReceived(tctx.client, b"foo")
<< layer.NextLayerHook(nl)
>> events.ConnectionClosed(tctx.client)
)
if layer_found:
nl.layer = tutils.RecordLayer(tctx)
assert (
playbook
>> tutils.reply(to=-2)
)
assert isinstance(nl.layer.event_log[-1], events.ConnectionClosed)
else:
assert (
playbook
>> tutils.reply(to=-2)
<< commands.CloseConnection(tctx.client)
)
def test_func_references(self, tctx):
nl = layer.NextLayer(tctx)
playbook = tutils.Playbook(nl)

View File

@ -356,6 +356,19 @@ class EchoLayer(Layer):
yield commands.CloseConnection(event.connection)
class RecordLayer(Layer):
"""Layer that records all events but does nothing."""
event_log: typing.List[events.Event]
def __init__(self, context: context.Context) -> None:
super().__init__(context)
self.event_log = []
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
self.event_log.append(event)
yield from ()
def reply_next_layer(
child_layer: typing.Union[typing.Type[Layer], typing.Callable[[context.Context], Layer]],
*args,