diff --git a/mitmproxy/proxy/protocol2/tcp.py b/mitmproxy/proxy/protocol2/tcp.py index 82062fc43..44c6bb070 100644 --- a/mitmproxy/proxy/protocol2/tcp.py +++ b/mitmproxy/proxy/protocol2/tcp.py @@ -33,6 +33,8 @@ class TCPLayer(Layer): self.flow.error = flow.Error(err) yield commands.Hook("tcp_error", self.flow) yield commands.CloseConnection(self.context.client) + self._handle_event = self.done + return self._handle_event = self.relay_messages _handle_event = start diff --git a/mitmproxy/proxy/protocol2/test/test_tcp.py b/mitmproxy/proxy/protocol2/test/test_tcp.py index 8329a8d57..58ea0371d 100644 --- a/mitmproxy/proxy/protocol2/test/test_tcp.py +++ b/mitmproxy/proxy/protocol2/test/test_tcp.py @@ -21,6 +21,20 @@ def test_open_connection(tctx): ) +def test_open_connection_err(tctx): + f = tutils.Placeholder() + assert ( + tutils.playbook(tcp.TCPLayer(tctx)) + << commands.Hook("tcp_start", f) + >> events.HookReply(-1, None) + << commands.OpenConnection(tctx.server) + >> events.OpenConnectionReply(-1, "Connect call failed") + << commands.Hook("tcp_error", f) + >> events.HookReply(-1, None) + << commands.CloseConnection(tctx.client) + ) + + def test_simple(tctx): """open connection, receive data, send it to peer""" f = tutils.Placeholder() @@ -34,18 +48,20 @@ def test_simple(tctx): >> events.OpenConnectionReply(-1, None) >> events.DataReceived(tctx.client, b"hello!") << commands.Hook("tcp_message", f) - ) - assert f().messages[0].content == b"hello!" - assert ( - playbook >> events.HookReply(-1, None) << commands.SendData(tctx.server, b"hello!") + >> events.DataReceived(tctx.server, b"hi") + << commands.Hook("tcp_message", f) + >> events.HookReply(-1, None) + << commands.SendData(tctx.client, b"hi") >> events.ConnectionClosed(tctx.server) << commands.CloseConnection(tctx.client) << commands.Hook("tcp_end", f) >> events.HookReply(-1, None) + >> events.ConnectionClosed(tctx.client) << None ) + assert len(f().messages) == 2 def test_simple_explicit(tctx): @@ -65,8 +81,7 @@ def test_simple_explicit(tctx): assert flow.messages[0].content == b"hello!" send, = layer.handle_event(events.HookReply(tcp_msg, None)) - assert tutils._eq(send, commands.SendData(tctx.server, b"hello!s" - b"")) + assert tutils._eq(send, commands.SendData(tctx.server, b"hello!")) close, tcp_end = layer.handle_event(events.ConnectionClosed(tctx.server)) assert tutils._eq(close, commands.CloseConnection(tctx.client)) assert tutils._eq(tcp_end, commands.Hook("tcp_end", flow)) @@ -120,3 +135,18 @@ def test_receive_data_before_server_connected(tctx): >> events.HookReply(-1, None) << commands.SendData(tctx.server, b"hello!") ) + + +def test_receive_data_after_server_disconnected(tctx): + """ + this should just be discarded. + """ + assert ( + tutils.playbook(tcp.TCPLayer(tctx, True)) + << commands.OpenConnection(tctx.server) + >> events.OpenConnectionReply(-1, None) + >> events.ConnectionClosed(tctx.server) + << commands.CloseConnection(tctx.client) + >> events.DataReceived(tctx.client, b"i'm late") + << None + )