unify HTTP trailers APIs

This commit is contained in:
Thomas Kriechbaumer 2020-07-02 16:05:02 +02:00
parent d589f13a1d
commit ebb061796c
13 changed files with 137 additions and 66 deletions

View File

@ -26,6 +26,7 @@ class HTTPRequest(http.Request):
http_version,
headers,
content,
trailers=None,
timestamp_start=None,
timestamp_end=None,
is_replay=False,
@ -41,6 +42,7 @@ class HTTPRequest(http.Request):
http_version,
headers,
content,
trailers,
timestamp_start,
timestamp_end,
)
@ -73,6 +75,7 @@ class HTTPRequest(http.Request):
http_version=request.data.http_version,
headers=request.data.headers,
content=request.data.content,
trailers=request.data.trailers,
timestamp_start=request.data.timestamp_start,
timestamp_end=request.data.timestamp_end,
)

View File

@ -174,6 +174,7 @@ def convert_6_7(data):
def convert_7_8(data):
data["version"] = 8
data["request"]["trailers"] = None
data["response"]["trailers"] = None
return data

View File

@ -59,7 +59,7 @@ def read_request_head(rfile):
timestamp_start = rfile.first_byte_timestamp
return request.Request(
form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
form, method, scheme, host, port, path, http_version, headers, None, None, timestamp_start
)

View File

@ -134,6 +134,20 @@ class Message(serializable.Serializable):
content = property(get_content, set_content)
@property
def trailers(self):
"""
Message trailers object
Returns:
mitmproxy.net.http.Headers
"""
return self.data.trailers
@trailers.setter
def trailers(self, h):
self.data.trailers = h
@property
def http_version(self):
"""

View File

@ -29,6 +29,7 @@ class RequestData(message.MessageData):
http_version,
headers=(),
content=None,
trailers=None,
timestamp_start=None,
timestamp_end=None
):
@ -46,6 +47,8 @@ class RequestData(message.MessageData):
headers = nheaders.Headers(headers)
if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
if trailers is not None and not isinstance(trailers, nheaders.Headers):
trailers = nheaders.Headers(trailers)
self.first_line_format = first_line_format
self.method = method
@ -56,6 +59,7 @@ class RequestData(message.MessageData):
self.http_version = http_version
self.headers = headers
self.content = content
self.trailers = trailers
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end

View File

@ -34,6 +34,8 @@ class ResponseData(message.MessageData):
headers = nheaders.Headers(headers)
if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
if trailers is not None and not isinstance(trailers, nheaders.Headers):
trailers = nheaders.Headers(trailers)
self.http_version = http_version
self.status_code = status_code

View File

@ -20,6 +20,9 @@ class _HttpTransmissionLayer(base.Layer):
def read_request_body(self, request):
raise NotImplementedError()
def read_request_trailers(self, request):
raise NotImplementedError()
def send_request(self, request):
raise NotImplementedError()
@ -30,11 +33,15 @@ class _HttpTransmissionLayer(base.Layer):
raise NotImplementedError()
yield "this is a generator" # pragma: no cover
def read_response_trailers(self, request, response):
raise NotImplementedError()
def read_response(self, request):
response = self.read_response_headers()
response.data.content = b"".join(
self.read_response_body(request, response)
)
response.data.trailers = self.read_response_trailers(request, response)
return response
def send_response(self, response):
@ -42,6 +49,7 @@ class _HttpTransmissionLayer(base.Layer):
raise exceptions.HttpException("Cannot assemble flow with missing content")
self.send_response_headers(response)
self.send_response_body(response, [response.data.content])
self.send_response_trailers(response)
def send_response_headers(self, response):
raise NotImplementedError()
@ -49,6 +57,9 @@ class _HttpTransmissionLayer(base.Layer):
def send_response_body(self, response, chunks):
raise NotImplementedError()
def send_response_trailers(self, response, chunks):
raise NotImplementedError()
def check_close_connection(self, f):
raise NotImplementedError()
@ -255,6 +266,7 @@ class HttpLayer(base.Layer):
f.request.data.content = b"".join(
self.read_request_body(f.request)
)
f.request.data.trailers = self.read_request_trailers(f.request)
f.request.timestamp_end = time.time()
self.channel.ask("http_connect", f)
@ -282,6 +294,9 @@ class HttpLayer(base.Layer):
f.request.data.content = None
else:
f.request.data.content = b"".join(self.read_request_body(request))
f.request.data.trailers = self.read_request_trailers(f.request)
request.timestamp_end = time.time()
except exceptions.HttpException as e:
# We optimistically guess there might be an HTTP client on the
@ -348,6 +363,8 @@ class HttpLayer(base.Layer):
else:
self.send_request_body(f.request, [f.request.data.content])
self.send_request_trailers(f.request)
f.response = self.read_response_headers()
try:
@ -406,10 +423,9 @@ class HttpLayer(base.Layer):
# we now need to emulate the responseheaders hook.
self.channel.ask("responseheaders", f)
f.response.data.trailers = self.read_response_trailers(f.request, f.response)
self.log("response", "debug", [repr(f.response)])
# not support HTTP/1.1 trailers
if f.request.http_version == "HTTP/2.0":
f.response.data.trailers = self.read_trailers_headers()
self.channel.ask("response", f)
if not f.response.stream:

View File

@ -23,6 +23,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
human.parse_size(self.config.options.body_size_limit)
)
def read_request_trailers(self, request):
if "Trailer" in request.headers:
# TODO: not implemented yet
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
return None
def send_request_headers(self, request):
headers = http1.assemble_request_head(request)
self.server_conn.wfile.write(headers)
@ -33,7 +39,13 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
self.server_conn.wfile.write(chunk)
self.server_conn.wfile.flush()
def send_request_trailers(self, request):
if "Trailer" in request.headers:
# TODO: not implemented yet
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
def send_request(self, request):
# TODO: this does not yet support request trailers
self.server_conn.wfile.write(http1.assemble_request(request))
self.server_conn.wfile.flush()
@ -49,6 +61,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
human.parse_size(self.config.options.body_size_limit)
)
def read_response_trailers(self, request, response):
if "Trailer" in response.headers:
# TODO: not implemented yet
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
return None
def send_response_headers(self, response):
raw = http1.assemble_response_head(response)
self.client_conn.wfile.write(raw)
@ -59,6 +77,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
self.client_conn.wfile.write(chunk)
self.client_conn.wfile.flush()
def send_response_trailers(self, response):
if "Trailer" in response.headers:
# TODO: not implemented yet
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
return
def check_close_connection(self, flow):
request_close = http1.connection_close(
flow.request.http_version,

View File

@ -235,8 +235,10 @@ class Http2Layer(base.Layer):
return True
def _handle_trailers(self, eid, event, is_server, other_conn):
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid].update_trailers(headers)
trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
# TODO: support request trailers as well!
self.streams[eid].response_trailers = trailers
self.streams[eid].response_trailers_arrived.set()
return True
def _handle_remote_settings_changed(self, event, other_conn):
@ -417,15 +419,17 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.request_data_queue: queue.Queue[bytes] = queue.Queue()
self.request_queued_data_length = 0
self.request_data_finished = threading.Event()
self.request_trailers_arrived = threading.Event()
self.request_trailers = None
self.response_arrived = threading.Event()
self.response_data_queue: queue.Queue[bytes] = queue.Queue()
self.response_queued_data_length = 0
self.response_data_finished = threading.Event()
self.response_trailers_arrived = threading.Event()
self.response_trailers = None
self.no_body = False
self.has_trailers = False
self.trailers_header = None
self.priority_exclusive: bool
self.priority_depends_on: Optional[int] = None
@ -437,8 +441,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.zombie = time.time()
self.request_data_finished.set()
self.request_arrived.set()
self.request_trailers_arrived.set()
self.response_arrived.set()
self.response_data_finished.set()
self.response_trailers_arrived.set()
def connect(self): # pragma: no cover
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
@ -526,6 +532,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
break
self.raise_zombie()
@detect_zombie_stream
def read_request_trailers(self, request):
if "trailer" in request.headers:
self.request_trailers_arrived.wait()
self.raise_zombie()
return self.request_trailers
return None
@detect_zombie_stream
def send_request_headers(self, request):
if self.pushed:
@ -600,25 +614,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
)
@detect_zombie_stream
def update_trailers(self, headers):
self.trailers_header = headers
self.has_trailers = True
def send_request_trailers(self, request):
self._send_trailers(self.server_conn, self.request_trailers)
@detect_zombie_stream
def send_trailers_headers(self):
if self.has_trailers and self.trailers_header:
with self.connections[self.client_conn].lock:
self.connections[self.client_conn].safe_send_headers(
self.raise_zombie,
self.client_stream_id,
self.trailers_header,
end_stream = True
)
@detect_zombie_stream
def send_request(self, message):
self.send_request_headers(message)
self.send_request_body(message, [message.content])
def send_request(self, request):
self.send_request_headers(request)
self.send_request_body(request, [request.content])
self.send_request_trailers(request)
@detect_zombie_stream
def read_response_headers(self):
@ -640,10 +643,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
timestamp_end=self.timestamp_end,
)
@detect_zombie_stream
def read_trailers_headers(self):
return self.trailers_header
@detect_zombie_stream
def read_response_body(self, request, response):
while True:
@ -658,6 +657,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
break
self.raise_zombie()
@detect_zombie_stream
def read_response_trailers(self, request, response):
if "trailer" in response.headers:
self.response_trailers_arrived.wait()
self.raise_zombie()
return self.response_trailers
return None
@detect_zombie_stream
def send_response_headers(self, response):
headers = response.headers.copy()
@ -670,15 +677,28 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
)
@detect_zombie_stream
def send_response_body(self, _response, chunks):
def send_response_body(self, response, chunks):
self.connections[self.client_conn].safe_send_body(
self.raise_zombie,
self.client_stream_id,
chunks,
end_stream = not self.has_trailers
end_stream=("trailer" not in response.headers)
)
if self.has_trailers:
self.send_trailers_headers()
@detect_zombie_stream
def send_response_trailers(self, _response):
self._send_trailers(self.client_conn, self.response_trailers)
def _send_trailers(self, conn, trailers):
if not trailers:
return
with self.connections[conn].lock:
self.connections[conn].safe_send_headers(
self.raise_zombie,
self.client_stream_id,
trailers,
end_stream=True
)
def __call__(self): # pragma: no cover
raise EnvironmentError('Http2SingleStreamLayer must be run as thread')

View File

@ -110,8 +110,9 @@ class HTTP2StateProtocol:
b"HTTP/2.0",
headers,
body,
timestamp_start,
timestamp_end,
None,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
request.stream_id = stream_id

View File

@ -21,8 +21,11 @@ class TestRequestData:
treq(headers="foobar")
with pytest.raises(ValueError):
treq(content="foobar")
with pytest.raises(ValueError):
treq(trailers="foobar")
assert isinstance(treq(headers=()).headers, Headers)
assert isinstance(treq(trailers=()).trailers, Headers)
class TestRequestCore:

View File

@ -20,8 +20,11 @@ class TestResponseData:
tresp(reason="fööbär")
with pytest.raises(ValueError):
tresp(content="foobar")
with pytest.raises(ValueError):
tresp(trailers="foobar")
assert isinstance(tresp(headers=()).headers, Headers)
assert isinstance(tresp(trailers=()).trailers, Headers)
class TestResponseCore:

View File

@ -1034,37 +1034,19 @@ class TestResponseStreaming(_Http2Test):
class TestTrailers(_Http2Test):
request_body_buffer = b''
@classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
assert (b'self.client-foo', b'self.client-bar-1') in event.headers
assert (b'self.client-foo', b'self.client-bar-2') in event.headers
elif isinstance(event, h2.events.StreamEnded):
import warnings
with warnings.catch_warnings():
# Ignore UnicodeWarning:
# h2/utilities.py:64: UnicodeWarning: Unicode equal comparison
# failed to convert both arguments to Unicode - interpreting
# them as being unequal.
# elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20:
warnings.simplefilter("ignore")
h2_conn.send_headers(event.stream_id, [
(':status', '200'),
('server-foo', 'server-bar'),
('föo', 'bär'),
('X-Stream-ID', str(event.stream_id)),
])
h2_conn.send_headers(event.stream_id, [
(':status', '200'),
('trailer', 'x-my-trailers')
])
h2_conn.send_data(event.stream_id, b'response body')
h2_conn.send_headers(event.stream_id, [('trailers', 'trailers-foo')], end_stream=True)
h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
wfile.write(h2_conn.data_to_send())
wfile.flush()
elif isinstance(event, h2.events.DataReceived):
cls.request_body_buffer += event.data
return True
def test_trailers(self):
@ -1079,11 +1061,9 @@ class TestTrailers(_Http2Test):
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
('self.client-FoO', 'self.client-bar-1'),
('self.client-FoO', 'self.client-bar-2'),
],
body=b'request body')
])
trailers_buffer = None
done = False
while not done:
try:
@ -1099,6 +1079,8 @@ class TestTrailers(_Http2Test):
for event in events:
if isinstance(event, h2.events.DataReceived):
response_body_buffer += event.data
elif isinstance(event, h2.events.TrailersReceived):
trailers_buffer = event.headers
elif isinstance(event, h2.events.StreamEnded):
done = True
@ -1108,9 +1090,7 @@ class TestTrailers(_Http2Test):
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.headers['server-foo'] == 'server-bar'
assert self.master.state.flows[0].response.headers['föo'] == 'bär'
assert self.master.state.flows[0].response.content == b'response body'
assert self.request_body_buffer == b'request body'
assert response_body_buffer == b'response body'
assert self.master.state.flows[0].response.data.trailers['trailers'] == 'trailers-foo'
assert self.master.state.flows[0].response.data.trailers['x-my-trailers'] == 'foobar'
assert trailers_buffer == [(b'x-my-trailers', b'foobar')]