diff --git a/test/test_protocol_http2.py b/test/test_protocol_http2.py index b42b86cb6..b1c88b274 100644 --- a/test/test_protocol_http2.py +++ b/test/test_protocol_http2.py @@ -235,6 +235,9 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): h2_conn.send_headers(2, [(':status', '202')]) h2_conn.send_headers(4, [(':status', '204')]) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + h2_conn.send_data(1, b'regular_stream') h2_conn.send_data(2, b'pushed_stream_foo') h2_conn.send_data(4, b'pushed_stream_bar') @@ -280,3 +283,37 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): assert b'regular_stream' in bodies assert b'pushed_stream_foo' in bodies assert b'pushed_stream_bar' in bodies + + def test_push_promise_reset(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + done = False + while not done: + try: + events = h2_conn.receive_data(utils.http2_read_frame(client.rfile)) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: + done = True + elif isinstance(event, h2.events.PushedStreamReceived): + h2_conn.reset_stream(event.pushed_stream_id) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + bodies = [flow.response.body for flow in self.master.state.flows] + assert len(bodies) == 3 + assert b'regular_stream' in bodies + assert b'pushed_stream_foo' in bodies + assert b'pushed_stream_bar' in bodies