diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index 5da91ac24..c8aaed8ab 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -180,7 +180,6 @@ class Http2Layer(base.Layer): headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers]) self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers) self.streams[eid].timestamp_start = time.time() - self.streams[eid].no_request_body = (event.stream_ended is not None) if event.priority_updated is not None: self.streams[eid].priority_exclusive = event.priority_updated.exclusive self.streams[eid].priority_depends_on = event.priority_updated.depends_on @@ -424,8 +423,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr self.request_message = self.Message(request_headers) self.response_message = self.Message() - self.no_request_body = False - self.priority_exclusive: bool self.priority_depends_on: Optional[int] = None self.priority_weight: Optional[int] = None @@ -594,7 +591,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr self.raise_zombie, self.server_stream_id, headers, - end_stream=self.no_request_body, priority_exclusive=priority_exclusive, priority_depends_on=priority_depends_on, priority_weight=priority_weight, @@ -611,12 +607,12 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr # nothing to do here return - if not self.no_request_body: - self.connections[self.server_conn].safe_send_body( - self.raise_zombie, - self.server_stream_id, - chunks - ) + self.connections[self.server_conn].safe_send_body( + self.raise_zombie, + self.server_stream_id, + chunks, + end_stream=(request.trailers is None), + ) @detect_zombie_stream def send_request_trailers(self, request): @@ -683,7 +679,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr self.raise_zombie, self.client_stream_id, chunks, - end_stream=("trailer" not in response.headers) + end_stream=(response.trailers is None), ) @detect_zombie_stream diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index 1529e7317..ba1070102 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -1033,29 +1033,110 @@ class TestResponseStreaming(_Http2Test): assert data is None -class TestTrailers(_Http2Test): +class TestRequestTrailers(_Http2Test): + server_trailers_received = False + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.RequestReceived): + # reset the value for a fresh test + cls.server_trailers_received = False + elif isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.TrailersReceived): + cls.server_trailers_received = True + + elif isinstance(event, h2.events.StreamEnded): + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('x-my-trailer-request-received', 'success' if cls.server_trailers_received else "failure"), + ], end_stream=True) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + @pytest.mark.parametrize('announce', [True, False]) + @pytest.mark.parametrize('body', [None, b"foobar"]) + def test_trailers(self, announce, body): + h2_conn = self.setup_connection() + stream_id = 1 + headers = [ + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ] + if announce: + headers.append(('trailer', 'x-my-trailers')) + h2_conn.send_headers( + stream_id=stream_id, + headers=headers, + ) + if body: + h2_conn.send_data(stream_id, body) + + # send trailers + h2_conn.send_headers(stream_id, [('x-my-trailers', 'foobar')], end_stream=True) + + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() + + done = False + while not done: + try: + raw = b''.join(http2.read_raw_frame(self.client.rfile)) + events = h2_conn.receive_data(raw) + except exceptions.HttpException: + print(traceback.format_exc()) + assert False + + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].request.trailers['x-my-trailers'] == 'foobar' + assert self.master.state.flows[0].response.status_code == 200 + assert self.master.state.flows[0].response.headers['x-my-trailer-request-received'] == 'success' + + +class TestResponseTrailers(_Http2Test): + @classmethod def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.StreamEnded): - h2_conn.send_headers(event.stream_id, [ + headers = [ (':status', '200'), - ('trailer', 'x-my-trailers') - ]) + ] + if event.stream_id == 1: + # special stream_id to activate the Trailer announcement header + headers.append(('trailer', 'x-my-trailers')) + + h2_conn.send_headers(event.stream_id, headers) h2_conn.send_data(event.stream_id, b'response body') h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True) wfile.write(h2_conn.data_to_send()) wfile.flush() return True - def test_trailers(self): + @pytest.mark.parametrize('announce', [True, False]) + def test_trailers(self, announce): response_body_buffer = b'' h2_conn = self.setup_connection() self._send_request( self.client.wfile, h2_conn, + stream_id=(1 if announce else 3), headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), @@ -1092,5 +1173,5 @@ class TestTrailers(_Http2Test): assert self.master.state.flows[0].response.status_code == 200 assert self.master.state.flows[0].response.content == b'response body' assert response_body_buffer == b'response body' - assert self.master.state.flows[0].response.data.trailers['x-my-trailers'] == 'foobar' + assert self.master.state.flows[0].response.trailers['x-my-trailers'] == 'foobar' assert trailers_buffer == [(b'x-my-trailers', b'foobar')]