diff --git a/libmproxy/protocol/http2.py b/libmproxy/protocol/http2.py index 54e7572e2..71423bf7f 100644 --- a/libmproxy/protocol/http2.py +++ b/libmproxy/protocol/http2.py @@ -17,7 +17,7 @@ from .base import Layer from .http import _HttpTransmissionLayer, HttpLayer from .. import utils from ..models import HTTPRequest, HTTPResponse - +from ..exceptions import HttpProtocolException, ProtocolException class SafeH2Connection(H2Connection): def __init__(self, conn, *args, **kwargs): @@ -207,7 +207,14 @@ class Http2Layer(Layer): is_server = (conn == self.server_conn.connection) with source_conn.h2.lock: - events = source_conn.h2.receive_data(utils.http2_read_frame(source_conn.rfile)) + try: + raw_frame = utils.http2_read_frame(source_conn.rfile) + except: + for stream in self.streams.values(): + stream.zombie = time.time() + return + + events = source_conn.h2.receive_data(raw_frame) source_conn.send(source_conn.h2.data_to_send()) for event in events: diff --git a/test/test_protocol_http2.py b/test/test_protocol_http2.py index e72113c46..4fa2c7013 100644 --- a/test/test_protocol_http2.py +++ b/test/test_protocol_http2.py @@ -28,9 +28,7 @@ requires_alpn = pytest.mark.skipif( class SimpleHttp2Server(netlib_tservers.ServerTestBase): - ssl = dict( - alpn_select=b'h2', - ) + ssl = dict(alpn_select=b'h2') class handler(netlib.tcp.BaseHandler): def handle(self): @@ -61,6 +59,59 @@ class SimpleHttp2Server(netlib_tservers.ServerTestBase): return +class PushHttp2Server(netlib_tservers.ServerTestBase): + ssl = dict(alpn_select=b'h2') + + class handler(netlib.tcp.BaseHandler): + def handle(self): + h2_conn = h2.connection.H2Connection(client_side=False) + + preamble = self.rfile.read(24) + h2_conn.initiate_connection() + h2_conn.receive_data(preamble) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + while True: + events = h2_conn.receive_data(utils.http2_read_frame(self.rfile)) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + for event in events: + if isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(1, [(':status', '200')]) + h2_conn.push_stream(1, 2, [ + (':authority', "127.0.0.1:%s" % self.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_foo'), + ('foo', 'bar') + ]) + h2_conn.push_stream(1, 4, [ + (':authority', "127.0.0.1:%s" % self.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/pushed_stream_bar'), + ('foo', 'bar') + ]) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + + h2_conn.send_headers(2, [(':status', '202')]) + h2_conn.send_headers(4, [(':status', '204')]) + h2_conn.send_data(1, b'regular_stream') + h2_conn.send_data(2, b'pushed_stream_foo') + h2_conn.send_data(4, b'pushed_stream_bar') + h2_conn.end_stream(1) + h2_conn.end_stream(2) + h2_conn.end_stream(4) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + print("HERE") + elif isinstance(event, h2.events.ConnectionTerminated): + return + + @requires_alpn class TestHttp2(tservers.ProxTestBase): def _setup_connection(self): @@ -132,3 +183,37 @@ class TestHttp2(tservers.ProxTestBase): assert self.master.state.flows[0].response.status_code == 200 assert self.master.state.flows[0].response.headers['foo'] == 'bar' assert self.master.state.flows[0].response.body == b'foobar' + + def test_pushed_streams(self): + self.server = PushHttp2Server() + self.server.setup_class() + + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, headers=[ + (':authority', "127.0.0.1:%s" % self.server.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('foo', 'bar') + ]) + + ended_streams = 0 + while ended_streams != 3: + try: + events = h2_conn.receive_data(utils.http2_read_frame(client.rfile)) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + + self.server.teardown_class() + + assert len(self.master.state.flows) == 3 + assert self.master.state.flows[0].response.body == b'regular_stream' + assert self.master.state.flows[1].response.body == b'pushed_stream_foo' + assert self.master.state.flows[2].response.body == b'pushed_stream_bar'