diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index 19c7c6041..9515eef93 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -478,6 +478,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def send_response_headers(self, response): headers = response.headers.copy() headers.insert(0, ":status", str(response.status_code)) + for forbidden_header in h2.utilities.CONNECTION_HEADERS: + if forbidden_header in headers: + del headers[forbidden_header] with self.client_conn.h2.lock: self.client_conn.h2.safe_send_headers( self.is_zombie, diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py index 932c8df2e..2eb0b120d 100644 --- a/test/mitmproxy/test_protocol_http2.py +++ b/test/mitmproxy/test_protocol_http2.py @@ -3,9 +3,10 @@ from __future__ import (absolute_import, print_function, division) import pytest -import traceback import os import tempfile +import traceback + import h2 from mitmproxy.proxy.config import ProxyConfig @@ -46,6 +47,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): self.wfile.write(h2_conn.data_to_send()) self.wfile.flush() + if 'h2_server_settings' in self.kwargs: + h2_conn.update_settings(self.kwargs['h2_server_settings']) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + done = False while not done: try: @@ -508,3 +514,120 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase): if len(self.master.state.flows) == 1: assert self.master.state.flows[0].response is None + + +@requires_alpn +class TestMaxConcurrentStreams(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2}) + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('X-Stream-ID', str(event.stream_id)), + ]) + h2_conn.send_data(event.stream_id, b'Stream-ID {}'.format(event.stream_id)) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_max_concurrent_streams(self): + client, h2_conn = self._setup_connection() + new_streams = [1, 3, 5, 7, 9, 11] + for id in new_streams: + # this will exceed MAX_CONCURRENT_STREAMS on the server connection + # and cause mitmproxy to throttle stream creation to the server + self._send_request(client.wfile, h2_conn, stream_id=id, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('X-Stream-ID', str(id)), + ]) + + ended_streams = 0 + while ended_streams != len(new_streams): + try: + header, body = framereader.http2_read_raw_frame(client.rfile) + events = h2_conn.receive_data(b''.join([header, body])) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == len(new_streams) + for flow in self.master.state.flows: + assert flow.response.status_code == 200 + assert "Stream-ID" in flow.response.body + + +@requires_alpn +class TestConnectionTerminated(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.RequestReceived): + h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data='foobar') + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_connection_terminated(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ]) + + done = False + connection_terminated_event = None + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + for event in events: + if isinstance(event, h2.events.ConnectionTerminated): + connection_terminated_event = event + done = True + except: + break + + assert len(self.master.state.flows) == 1 + assert connection_terminated_event is not None + assert connection_terminated_event.error_code == 5 + assert connection_terminated_event.last_stream_id == 42 + assert connection_terminated_event.additional_data == 'foobar' diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py index 803aaa727..666f97ac8 100644 --- a/test/netlib/tservers.py +++ b/test/netlib/tservers.py @@ -24,7 +24,7 @@ class _ServerThread(threading.Thread): class _TServer(tcp.TCPServer): - def __init__(self, ssl, q, handler_klass, addr): + def __init__(self, ssl, q, handler_klass, addr, **kwargs): """ ssl: A dictionary of SSL parameters: @@ -42,6 +42,8 @@ class _TServer(tcp.TCPServer): self.q = q self.handler_klass = handler_klass + if self.handler_klass is not None: + self.handler_klass.kwargs = kwargs self.last_handler = None def handle_client_connection(self, request, client_address): @@ -89,16 +91,16 @@ class ServerTestBase(object): addr = ("localhost", 0) @classmethod - def setup_class(cls): + def setup_class(cls, **kwargs): cls.q = queue.Queue() - s = cls.makeserver() + s = cls.makeserver(**kwargs) cls.port = s.address.port cls.server = _ServerThread(s) cls.server.start() @classmethod - def makeserver(cls): - return _TServer(cls.ssl, cls.q, cls.handler, cls.addr) + def makeserver(cls, **kwargs): + return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs) @classmethod def teardown_class(cls):