fix missing message body and end_stream for trailers

This commit is contained in:
Thomas Kriechbaumer 2020-07-06 12:36:54 +02:00
parent 828ba0c2e7
commit c0f62cc559
2 changed files with 94 additions and 17 deletions

View File

@ -180,7 +180,6 @@ class Http2Layer(base.Layer):
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers]) headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers) self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers)
self.streams[eid].timestamp_start = time.time() self.streams[eid].timestamp_start = time.time()
self.streams[eid].no_request_body = (event.stream_ended is not None)
if event.priority_updated is not None: if event.priority_updated is not None:
self.streams[eid].priority_exclusive = event.priority_updated.exclusive self.streams[eid].priority_exclusive = event.priority_updated.exclusive
self.streams[eid].priority_depends_on = event.priority_updated.depends_on self.streams[eid].priority_depends_on = event.priority_updated.depends_on
@ -424,8 +423,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.request_message = self.Message(request_headers) self.request_message = self.Message(request_headers)
self.response_message = self.Message() self.response_message = self.Message()
self.no_request_body = False
self.priority_exclusive: bool self.priority_exclusive: bool
self.priority_depends_on: Optional[int] = None self.priority_depends_on: Optional[int] = None
self.priority_weight: Optional[int] = None self.priority_weight: Optional[int] = None
@ -594,7 +591,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.raise_zombie, self.raise_zombie,
self.server_stream_id, self.server_stream_id,
headers, headers,
end_stream=self.no_request_body,
priority_exclusive=priority_exclusive, priority_exclusive=priority_exclusive,
priority_depends_on=priority_depends_on, priority_depends_on=priority_depends_on,
priority_weight=priority_weight, priority_weight=priority_weight,
@ -611,12 +607,12 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
# nothing to do here # nothing to do here
return return
if not self.no_request_body: self.connections[self.server_conn].safe_send_body(
self.connections[self.server_conn].safe_send_body( self.raise_zombie,
self.raise_zombie, self.server_stream_id,
self.server_stream_id, chunks,
chunks end_stream=(request.trailers is None),
) )
@detect_zombie_stream @detect_zombie_stream
def send_request_trailers(self, request): def send_request_trailers(self, request):
@ -683,7 +679,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.raise_zombie, self.raise_zombie,
self.client_stream_id, self.client_stream_id,
chunks, chunks,
end_stream=("trailer" not in response.headers) end_stream=(response.trailers is None),
) )
@detect_zombie_stream @detect_zombie_stream

View File

@ -1033,29 +1033,110 @@ class TestResponseStreaming(_Http2Test):
assert data is None assert data is None
class TestTrailers(_Http2Test): class TestRequestTrailers(_Http2Test):
server_trailers_received = False
@classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.RequestReceived):
# reset the value for a fresh test
cls.server_trailers_received = False
elif isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.TrailersReceived):
cls.server_trailers_received = True
elif isinstance(event, h2.events.StreamEnded):
h2_conn.send_headers(event.stream_id, [
(':status', '200'),
('x-my-trailer-request-received', 'success' if cls.server_trailers_received else "failure"),
], end_stream=True)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
@pytest.mark.parametrize('announce', [True, False])
@pytest.mark.parametrize('body', [None, b"foobar"])
def test_trailers(self, announce, body):
h2_conn = self.setup_connection()
stream_id = 1
headers = [
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
]
if announce:
headers.append(('trailer', 'x-my-trailers'))
h2_conn.send_headers(
stream_id=stream_id,
headers=headers,
)
if body:
h2_conn.send_data(stream_id, body)
# send trailers
h2_conn.send_headers(stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
done = False
while not done:
try:
raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamEnded):
done = True
h2_conn.close_connection()
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].request.trailers['x-my-trailers'] == 'foobar'
assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.headers['x-my-trailer-request-received'] == 'success'
class TestResponseTrailers(_Http2Test):
@classmethod @classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile): def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated): if isinstance(event, h2.events.ConnectionTerminated):
return False return False
elif isinstance(event, h2.events.StreamEnded): elif isinstance(event, h2.events.StreamEnded):
h2_conn.send_headers(event.stream_id, [ headers = [
(':status', '200'), (':status', '200'),
('trailer', 'x-my-trailers') ]
]) if event.stream_id == 1:
# special stream_id to activate the Trailer announcement header
headers.append(('trailer', 'x-my-trailers'))
h2_conn.send_headers(event.stream_id, headers)
h2_conn.send_data(event.stream_id, b'response body') h2_conn.send_data(event.stream_id, b'response body')
h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True) h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
wfile.write(h2_conn.data_to_send()) wfile.write(h2_conn.data_to_send())
wfile.flush() wfile.flush()
return True return True
def test_trailers(self): @pytest.mark.parametrize('announce', [True, False])
def test_trailers(self, announce):
response_body_buffer = b'' response_body_buffer = b''
h2_conn = self.setup_connection() h2_conn = self.setup_connection()
self._send_request( self._send_request(
self.client.wfile, self.client.wfile,
h2_conn, h2_conn,
stream_id=(1 if announce else 3),
headers=[ headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'), (':method', 'GET'),
@ -1092,5 +1173,5 @@ class TestTrailers(_Http2Test):
assert self.master.state.flows[0].response.status_code == 200 assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.content == b'response body' assert self.master.state.flows[0].response.content == b'response body'
assert response_body_buffer == b'response body' assert response_body_buffer == b'response body'
assert self.master.state.flows[0].response.data.trailers['x-my-trailers'] == 'foobar' assert self.master.state.flows[0].response.trailers['x-my-trailers'] == 'foobar'
assert trailers_buffer == [(b'x-my-trailers', b'foobar')] assert trailers_buffer == [(b'x-my-trailers', b'foobar')]