From 618044c637318749c8bba86d73dd6191b02e24d9 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 24 May 2017 11:43:50 +0200 Subject: [PATCH] http2 tests: fix leaking sockets --- test/mitmproxy/proxy/protocol/test_http2.py | 172 ++++++++++---------- 1 file changed, 88 insertions(+), 84 deletions(-) diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index b07257b3b..261f8415c 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -118,12 +118,16 @@ class _Http2TestBase: self.master.reset([]) self.server.server.handle_server_event = self.handle_server_event - def _setup_connection(self): - client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port)) - client.connect() + def teardown(self): + if self.client: + self.client.close() + + def setup_connection(self): + self.client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + self.client.connect() # send CONNECT request - client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request( + self.client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request( 'authority', b'CONNECT', b'', @@ -134,13 +138,13 @@ class _Http2TestBase: [(b'host', b'localhost:%d' % self.server.server.address[1])], b'', ))) - client.wfile.flush() + self.client.wfile.flush() # read CONNECT response - while client.rfile.readline() != b"\r\n": + while self.client.rfile.readline() != b"\r\n": pass - client.convert_to_ssl(alpn_protos=[b'h2']) + self.client.convert_to_ssl(alpn_protos=[b'h2']) config = h2.config.H2Configuration( client_side=True, @@ -148,10 +152,10 @@ class _Http2TestBase: validate_inbound_headers=False) h2_conn = h2.connection.H2Connection(config) h2_conn.initiate_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() - return client, h2_conn + return h2_conn def _send_request(self, wfile, @@ -205,8 +209,8 @@ class TestSimple(_Http2Test): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.RequestReceived): - assert (b'client-foo', b'client-bar-1') in event.headers - assert (b'client-foo', b'client-bar-2') in event.headers + assert (b'self.client-foo', b'self.client-bar-1') in event.headers + assert (b'self.client-foo', b'self.client-bar-2') in event.headers elif isinstance(event, h2.events.StreamEnded): import warnings with warnings.catch_warnings(): @@ -233,32 +237,32 @@ class TestSimple(_Http2Test): def test_simple(self): response_body_buffer = b'' - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), - ('ClIeNt-FoO', 'client-bar-1'), - ('ClIeNt-FoO', 'client-bar-2'), + ('self.client-FoO', 'self.client-bar-1'), + ('self.client-FoO', 'self.client-bar-2'), ], body=b'request body') done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.DataReceived): @@ -267,8 +271,8 @@ class TestSimple(_Http2Test): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 assert self.master.state.flows[0].response.status_code == 200 @@ -317,10 +321,10 @@ class TestRequestWithPriority(_Http2Test): def test_request_with_priority(self, http2_priority_enabled, priority, expected_priority): self.config.options.http2_priority = http2_priority_enabled - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -336,22 +340,22 @@ class TestRequestWithPriority(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 @@ -397,15 +401,15 @@ class TestPriority(_Http2Test): self.config.options.http2_priority = http2_priority_enabled self.__class__.priority_data = [] - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() if prioritize_before: h2_conn.prioritize(1, exclusive=priority[0], depends_on=priority[1], weight=priority[2]) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -419,28 +423,28 @@ class TestPriority(_Http2Test): if not prioritize_before: h2_conn.prioritize(1, exclusive=priority[0], depends_on=priority[1], weight=priority[2]) h2_conn.end_stream(1) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 assert self.priority_data == expected_priority @@ -460,10 +464,10 @@ class TestStreamResetFromServer(_Http2Test): return True def test_request_with_priority(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -476,22 +480,22 @@ class TestStreamResetFromServer(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamReset): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 assert self.master.state.flows[0].response is None @@ -510,10 +514,10 @@ class TestBodySizeLimit(_Http2Test): self.config.options.body_size_limit = "20" self.config.options._processed["body_size_limit"] = 20 - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -527,22 +531,22 @@ class TestBodySizeLimit(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamReset): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 0 @@ -609,9 +613,9 @@ class TestPushPromise(_Http2Test): return True def test_push_promise(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -625,15 +629,15 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False except: break - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): @@ -649,8 +653,8 @@ class TestPushPromise(_Http2Test): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert ended_streams == 3 assert pushed_streams == 2 @@ -665,9 +669,9 @@ class TestPushPromise(_Http2Test): assert len(pushed_flows) == 2 def test_push_promise_reset(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -681,14 +685,14 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: @@ -696,8 +700,8 @@ class TestPushPromise(_Http2Test): elif isinstance(event, h2.events.PushedStreamReceived): pushed_streams += 1 h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() elif isinstance(event, h2.events.ResponseReceived): responses += 1 if isinstance(event, h2.events.ConnectionTerminated): @@ -707,8 +711,8 @@ class TestPushPromise(_Http2Test): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() bodies = [flow.response.content for flow in self.master.state.flows if flow.response] assert len(bodies) >= 1 @@ -728,9 +732,9 @@ class TestConnectionLost(_Http2Test): return False def test_connection_lost(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -741,7 +745,7 @@ class TestConnectionLost(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) @@ -749,8 +753,8 @@ class TestConnectionLost(_Http2Test): except: break try: - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() except: break @@ -782,12 +786,12 @@ class TestMaxConcurrentStreams(_Http2Test): return True def test_max_concurrent_streams(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() new_streams = [1, 3, 5, 7, 9, 11] for stream_id in new_streams: # this will exceed MAX_CONCURRENT_STREAMS on the server connection # and cause mitmproxy to throttle stream creation to the server - self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=stream_id, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -798,20 +802,20 @@ class TestMaxConcurrentStreams(_Http2Test): ended_streams = 0 while ended_streams != len(new_streams): try: - header, body = http2.read_raw_frame(client.rfile) + header, body = http2.read_raw_frame(self.client.rfile) events = h2_conn.receive_data(b''.join([header, body])) except: break - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): ended_streams += 1 h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == len(new_streams) for flow in self.master.state.flows: @@ -831,9 +835,9 @@ class TestConnectionTerminated(_Http2Test): return True def test_connection_terminated(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, headers=[ + self._send_request(self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -844,7 +848,7 @@ class TestConnectionTerminated(_Http2Test): connection_terminated_event = None while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) for event in events: if isinstance(event, h2.events.ConnectionTerminated):