Merge pull request #4042 from sanlengjingvv/develop

support HTTP/2 trailers
This commit is contained in:
Thomas Kriechbaumer 2020-07-06 17:14:17 +02:00 committed by GitHub
commit 46a0f69485
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 386 additions and 73 deletions

View File

@ -0,0 +1,28 @@
"""
This script simply prints all received HTTP Trailers.
HTTP requests and responses can container trailing headers which are sent after
the body is fully transmitted. Such trailers need to be announced in the initial
headers by name, so the receiving endpoint can wait and read them after the
body.
"""
from mitmproxy import http
from mitmproxy.net.http import Headers
def request(flow: http.HTTPFlow):
if flow.request.trailers:
print("HTTP Trailers detected! Request contains:", flow.request.trailers)
def response(flow: http.HTTPFlow):
if flow.response.trailers:
print("HTTP Trailers detected! Response contains:", flow.response.trailers)
if flow.request.path == "/inject_trailers":
flow.response.headers["trailer"] = "x-my-injected-trailer-header"
flow.response.trailers = Headers([
(b"x-my-injected-trailer-header", b"foobar")
])
print("Injected a new trailer...", flow.response.headers["trailer"])

View File

@ -26,6 +26,7 @@ class HTTPRequest(http.Request):
http_version, http_version,
headers, headers,
content, content,
trailers=None,
timestamp_start=None, timestamp_start=None,
timestamp_end=None, timestamp_end=None,
is_replay=False, is_replay=False,
@ -41,6 +42,7 @@ class HTTPRequest(http.Request):
http_version, http_version,
headers, headers,
content, content,
trailers,
timestamp_start, timestamp_start,
timestamp_end, timestamp_end,
) )
@ -73,6 +75,7 @@ class HTTPRequest(http.Request):
http_version=request.data.http_version, http_version=request.data.http_version,
headers=request.data.headers, headers=request.data.headers,
content=request.data.content, content=request.data.content,
trailers=request.data.trailers,
timestamp_start=request.data.timestamp_start, timestamp_start=request.data.timestamp_start,
timestamp_end=request.data.timestamp_end, timestamp_end=request.data.timestamp_end,
) )
@ -97,6 +100,7 @@ class HTTPResponse(http.Response):
reason, reason,
headers, headers,
content, content,
trailers=None,
timestamp_start=None, timestamp_start=None,
timestamp_end=None, timestamp_end=None,
is_replay=False is_replay=False
@ -108,6 +112,7 @@ class HTTPResponse(http.Response):
reason, reason,
headers, headers,
content, content,
trailers,
timestamp_start=timestamp_start, timestamp_start=timestamp_start,
timestamp_end=timestamp_end, timestamp_end=timestamp_end,
) )
@ -127,6 +132,7 @@ class HTTPResponse(http.Response):
reason=response.data.reason, reason=response.data.reason,
headers=response.data.headers, headers=response.data.headers,
content=response.data.content, content=response.data.content,
trailers=response.data.trailers,
timestamp_start=response.data.timestamp_start, timestamp_start=response.data.timestamp_start,
timestamp_end=response.data.timestamp_end, timestamp_end=response.data.timestamp_end,
) )

View File

@ -172,6 +172,13 @@ def convert_6_7(data):
return data return data
def convert_7_8(data):
data["version"] = 8
data["request"]["trailers"] = None
data["response"]["trailers"] = None
return data
def _convert_dict_keys(o: Any) -> Any: def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict): if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
@ -226,6 +233,7 @@ converters = {
4: convert_4_5, 4: convert_4_5,
5: convert_5_6, 5: convert_5_6,
6: convert_6_7, 6: convert_6_7,
7: convert_7_8,
} }

View File

@ -59,7 +59,7 @@ def read_request_head(rfile):
timestamp_start = rfile.first_byte_timestamp timestamp_start = rfile.first_byte_timestamp
return request.Request( return request.Request(
form, method, scheme, host, port, path, http_version, headers, None, timestamp_start form, method, scheme, host, port, path, http_version, headers, None, None, timestamp_start
) )
@ -98,7 +98,7 @@ def read_response_head(rfile):
# more accurate timestamp_start # more accurate timestamp_start
timestamp_start = rfile.first_byte_timestamp timestamp_start = rfile.first_byte_timestamp
return response.Response(http_version, status_code, message, headers, None, timestamp_start) return response.Response(http_version, status_code, message, headers, None, None, timestamp_start)
def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):

View File

@ -28,6 +28,8 @@ class MessageData(serializable.Serializable):
def get_state(self): def get_state(self):
state = vars(self).copy() state = vars(self).copy()
state["headers"] = state["headers"].get_state() state["headers"] = state["headers"].get_state()
if 'trailers' in state and state["trailers"] is not None:
state["trailers"] = state["trailers"].get_state()
return state return state
@classmethod @classmethod
@ -53,6 +55,8 @@ class Message(serializable.Serializable):
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
state["headers"] = mheaders.Headers.from_state(state["headers"]) state["headers"] = mheaders.Headers.from_state(state["headers"])
if 'trailers' in state and state["trailers"] is not None:
state["trailers"] = mheaders.Headers.from_state(state["trailers"])
return cls(**state) return cls(**state)
@property @property
@ -130,6 +134,20 @@ class Message(serializable.Serializable):
content = property(get_content, set_content) content = property(get_content, set_content)
@property
def trailers(self):
"""
Message trailers object
Returns:
mitmproxy.net.http.Headers
"""
return self.data.trailers
@trailers.setter
def trailers(self, h):
self.data.trailers = h
@property @property
def http_version(self): def http_version(self):
""" """

View File

@ -29,6 +29,7 @@ class RequestData(message.MessageData):
http_version, http_version,
headers=(), headers=(),
content=None, content=None,
trailers=None,
timestamp_start=None, timestamp_start=None,
timestamp_end=None timestamp_end=None
): ):
@ -46,6 +47,8 @@ class RequestData(message.MessageData):
headers = nheaders.Headers(headers) headers = nheaders.Headers(headers)
if isinstance(content, str): if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
if trailers is not None and not isinstance(trailers, nheaders.Headers):
trailers = nheaders.Headers(trailers)
self.first_line_format = first_line_format self.first_line_format = first_line_format
self.method = method self.method = method
@ -56,6 +59,7 @@ class RequestData(message.MessageData):
self.http_version = http_version self.http_version = http_version
self.headers = headers self.headers = headers
self.content = content self.content = content
self.trailers = trailers
self.timestamp_start = timestamp_start self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end self.timestamp_end = timestamp_end

View File

@ -22,6 +22,7 @@ class ResponseData(message.MessageData):
reason=None, reason=None,
headers=(), headers=(),
content=None, content=None,
trailers=None,
timestamp_start=None, timestamp_start=None,
timestamp_end=None timestamp_end=None
): ):
@ -33,12 +34,15 @@ class ResponseData(message.MessageData):
headers = nheaders.Headers(headers) headers = nheaders.Headers(headers)
if isinstance(content, str): if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
if trailers is not None and not isinstance(trailers, nheaders.Headers):
trailers = nheaders.Headers(trailers)
self.http_version = http_version self.http_version = http_version
self.status_code = status_code self.status_code = status_code
self.reason = reason self.reason = reason
self.headers = headers self.headers = headers
self.content = content self.content = content
self.trailers = trailers
self.timestamp_start = timestamp_start self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end self.timestamp_end = timestamp_end

View File

@ -20,6 +20,9 @@ class _HttpTransmissionLayer(base.Layer):
def read_request_body(self, request): def read_request_body(self, request):
raise NotImplementedError() raise NotImplementedError()
def read_request_trailers(self, request):
raise NotImplementedError()
def send_request(self, request): def send_request(self, request):
raise NotImplementedError() raise NotImplementedError()
@ -30,11 +33,15 @@ class _HttpTransmissionLayer(base.Layer):
raise NotImplementedError() raise NotImplementedError()
yield "this is a generator" # pragma: no cover yield "this is a generator" # pragma: no cover
def read_response_trailers(self, request, response):
raise NotImplementedError()
def read_response(self, request): def read_response(self, request):
response = self.read_response_headers() response = self.read_response_headers()
response.data.content = b"".join( response.data.content = b"".join(
self.read_response_body(request, response) self.read_response_body(request, response)
) )
response.data.trailers = self.read_response_trailers(request, response)
return response return response
def send_response(self, response): def send_response(self, response):
@ -42,6 +49,7 @@ class _HttpTransmissionLayer(base.Layer):
raise exceptions.HttpException("Cannot assemble flow with missing content") raise exceptions.HttpException("Cannot assemble flow with missing content")
self.send_response_headers(response) self.send_response_headers(response)
self.send_response_body(response, [response.data.content]) self.send_response_body(response, [response.data.content])
self.send_response_trailers(response)
def send_response_headers(self, response): def send_response_headers(self, response):
raise NotImplementedError() raise NotImplementedError()
@ -49,6 +57,9 @@ class _HttpTransmissionLayer(base.Layer):
def send_response_body(self, response, chunks): def send_response_body(self, response, chunks):
raise NotImplementedError() raise NotImplementedError()
def send_response_trailers(self, response, chunks):
raise NotImplementedError()
def check_close_connection(self, f): def check_close_connection(self, f):
raise NotImplementedError() raise NotImplementedError()
@ -255,6 +266,7 @@ class HttpLayer(base.Layer):
f.request.data.content = b"".join( f.request.data.content = b"".join(
self.read_request_body(f.request) self.read_request_body(f.request)
) )
f.request.data.trailers = self.read_request_trailers(f.request)
f.request.timestamp_end = time.time() f.request.timestamp_end = time.time()
self.channel.ask("http_connect", f) self.channel.ask("http_connect", f)
@ -282,6 +294,9 @@ class HttpLayer(base.Layer):
f.request.data.content = None f.request.data.content = None
else: else:
f.request.data.content = b"".join(self.read_request_body(request)) f.request.data.content = b"".join(self.read_request_body(request))
f.request.data.trailers = self.read_request_trailers(f.request)
request.timestamp_end = time.time() request.timestamp_end = time.time()
except exceptions.HttpException as e: except exceptions.HttpException as e:
# We optimistically guess there might be an HTTP client on the # We optimistically guess there might be an HTTP client on the
@ -348,6 +363,8 @@ class HttpLayer(base.Layer):
else: else:
self.send_request_body(f.request, [f.request.data.content]) self.send_request_body(f.request, [f.request.data.content])
self.send_request_trailers(f.request)
f.response = self.read_response_headers() f.response = self.read_response_headers()
try: try:
@ -406,6 +423,8 @@ class HttpLayer(base.Layer):
# we now need to emulate the responseheaders hook. # we now need to emulate the responseheaders hook.
self.channel.ask("responseheaders", f) self.channel.ask("responseheaders", f)
f.response.data.trailers = self.read_response_trailers(f.request, f.response)
self.log("response", "debug", [repr(f.response)]) self.log("response", "debug", [repr(f.response)])
self.channel.ask("response", f) self.channel.ask("response", f)

View File

@ -23,6 +23,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
human.parse_size(self.config.options.body_size_limit) human.parse_size(self.config.options.body_size_limit)
) )
def read_request_trailers(self, request):
if "Trailer" in request.headers:
# TODO: not implemented yet
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
return None
def send_request_headers(self, request): def send_request_headers(self, request):
headers = http1.assemble_request_head(request) headers = http1.assemble_request_head(request)
self.server_conn.wfile.write(headers) self.server_conn.wfile.write(headers)
@ -33,7 +39,13 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
self.server_conn.wfile.write(chunk) self.server_conn.wfile.write(chunk)
self.server_conn.wfile.flush() self.server_conn.wfile.flush()
def send_request_trailers(self, request):
if "Trailer" in request.headers:
# TODO: not implemented yet
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
def send_request(self, request): def send_request(self, request):
# TODO: this does not yet support request trailers
self.server_conn.wfile.write(http1.assemble_request(request)) self.server_conn.wfile.write(http1.assemble_request(request))
self.server_conn.wfile.flush() self.server_conn.wfile.flush()
@ -49,6 +61,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
human.parse_size(self.config.options.body_size_limit) human.parse_size(self.config.options.body_size_limit)
) )
def read_response_trailers(self, request, response):
if "Trailer" in response.headers:
# TODO: not implemented yet
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
return None
def send_response_headers(self, response): def send_response_headers(self, response):
raw = http1.assemble_response_head(response) raw = http1.assemble_response_head(response)
self.client_conn.wfile.write(raw) self.client_conn.wfile.write(raw)
@ -59,6 +77,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
self.client_conn.wfile.write(chunk) self.client_conn.wfile.write(chunk)
self.client_conn.wfile.flush() self.client_conn.wfile.flush()
def send_response_trailers(self, response):
if "Trailer" in response.headers:
# TODO: not implemented yet
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
return
def check_close_connection(self, flow): def check_close_connection(self, flow):
request_close = http1.connection_close( request_close = http1.connection_close(
flow.request.http_version, flow.request.http_version,

View File

@ -55,7 +55,7 @@ class SafeH2Connection(connection.H2Connection):
self.send_headers(stream_id, headers.fields, **kwargs) self.send_headers(stream_id, headers.fields, **kwargs)
self.conn.send(self.data_to_send()) self.conn.send(self.data_to_send())
def safe_send_body(self, raise_zombie: Callable, stream_id: int, chunks: List[bytes]): def safe_send_body(self, raise_zombie: Callable, stream_id: int, chunks: List[bytes], end_stream=True):
for chunk in chunks: for chunk in chunks:
position = 0 position = 0
while position < len(chunk): while position < len(chunk):
@ -75,6 +75,7 @@ class SafeH2Connection(connection.H2Connection):
finally: finally:
self.lock.release() self.lock.release()
position += max_outbound_frame_size position += max_outbound_frame_size
if end_stream:
with self.lock: with self.lock:
raise_zombie() raise_zombie()
self.end_stream(stream_id) self.end_stream(stream_id)
@ -170,7 +171,7 @@ class Http2Layer(base.Layer):
elif isinstance(event, events.PriorityUpdated): elif isinstance(event, events.PriorityUpdated):
return self._handle_priority_updated(eid, event) return self._handle_priority_updated(eid, event)
elif isinstance(event, events.TrailersReceived): elif isinstance(event, events.TrailersReceived):
raise NotImplementedError('TrailersReceived not implemented') return self._handle_trailers(eid, event, is_server, other_conn)
# fail-safe for unhandled events # fail-safe for unhandled events
return True return True
@ -179,22 +180,21 @@ class Http2Layer(base.Layer):
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers]) headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers) self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers)
self.streams[eid].timestamp_start = time.time() self.streams[eid].timestamp_start = time.time()
self.streams[eid].no_body = (event.stream_ended is not None)
if event.priority_updated is not None: if event.priority_updated is not None:
self.streams[eid].priority_exclusive = event.priority_updated.exclusive self.streams[eid].priority_exclusive = event.priority_updated.exclusive
self.streams[eid].priority_depends_on = event.priority_updated.depends_on self.streams[eid].priority_depends_on = event.priority_updated.depends_on
self.streams[eid].priority_weight = event.priority_updated.weight self.streams[eid].priority_weight = event.priority_updated.weight
self.streams[eid].handled_priority_event = event.priority_updated self.streams[eid].handled_priority_event = event.priority_updated
self.streams[eid].start() self.streams[eid].start()
self.streams[eid].request_arrived.set() self.streams[eid].request_message.arrived.set()
return True return True
def _handle_response_received(self, eid, event): def _handle_response_received(self, eid, event):
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers]) headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid].queued_data_length = 0 self.streams[eid].queued_data_length = 0
self.streams[eid].timestamp_start = time.time() self.streams[eid].timestamp_start = time.time()
self.streams[eid].response_headers = headers self.streams[eid].response_message.headers = headers
self.streams[eid].response_arrived.set() self.streams[eid].response_message.arrived.set()
return True return True
def _handle_data_received(self, eid, event, source_conn): def _handle_data_received(self, eid, event, source_conn):
@ -219,7 +219,7 @@ class Http2Layer(base.Layer):
def _handle_stream_ended(self, eid): def _handle_stream_ended(self, eid):
self.streams[eid].timestamp_end = time.time() self.streams[eid].timestamp_end = time.time()
self.streams[eid].data_finished.set() self.streams[eid].stream_ended.set()
return True return True
def _handle_stream_reset(self, eid, event, is_server, other_conn): def _handle_stream_reset(self, eid, event, is_server, other_conn):
@ -233,6 +233,11 @@ class Http2Layer(base.Layer):
self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code) self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code)
return True return True
def _handle_trailers(self, eid, event, is_server, other_conn):
trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid].trailers = trailers
return True
def _handle_remote_settings_changed(self, event, other_conn): def _handle_remote_settings_changed(self, event, other_conn):
new_settings = dict([(key, cs.new_value) for (key, cs) in event.changed_settings.items()]) new_settings = dict([(key, cs.new_value) for (key, cs) in event.changed_settings.items()])
self.connections[other_conn].safe_update_settings(new_settings) self.connections[other_conn].safe_update_settings(new_settings)
@ -277,8 +282,8 @@ class Http2Layer(base.Layer):
self.streams[event.pushed_stream_id].pushed = True self.streams[event.pushed_stream_id].pushed = True
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
self.streams[event.pushed_stream_id].timestamp_end = time.time() self.streams[event.pushed_stream_id].timestamp_end = time.time()
self.streams[event.pushed_stream_id].request_arrived.set() self.streams[event.pushed_stream_id].request_message.arrived.set()
self.streams[event.pushed_stream_id].request_data_finished.set() self.streams[event.pushed_stream_id].request_message.stream_ended.set()
self.streams[event.pushed_stream_id].start() self.streams[event.pushed_stream_id].start()
return True return True
@ -392,6 +397,16 @@ def detect_zombie_stream(func): # pragma: no cover
class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThread): class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThread):
class Message:
def __init__(self, headers=None):
self.headers: Optional[mitmproxy.net.http.Headers] = headers # headers are the first thing to be received on a new stream
self.data_queue: queue.Queue[bytes] = queue.Queue() # contains raw contents of DATA frames
self.queued_data_length = 0 # used to enforce mitmproxy's config.options.body_size_limit
self.trailers: Optional[mitmproxy.net.http.Headers] = None # trailers are received after stream_ended is set
self.arrived = threading.Event() # indicates the HEADERS+CONTINUTATION frames have been received
self.stream_ended = threading.Event() # indicates the a frame with the END_STREAM flag has been received
def __init__(self, ctx, h2_connection, stream_id: int, request_headers: mitmproxy.net.http.Headers) -> None: def __init__(self, ctx, h2_connection, stream_id: int, request_headers: mitmproxy.net.http.Headers) -> None:
super().__init__( super().__init__(
ctx, name="Http2SingleStreamLayer-{}".format(stream_id) ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
@ -400,24 +415,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.zombie: Optional[float] = None self.zombie: Optional[float] = None
self.client_stream_id: int = stream_id self.client_stream_id: int = stream_id
self.server_stream_id: Optional[int] = None self.server_stream_id: Optional[int] = None
self.request_headers = request_headers
self.response_headers: Optional[mitmproxy.net.http.Headers] = None
self.pushed = False self.pushed = False
self.timestamp_start: Optional[float] = None self.timestamp_start: Optional[float] = None
self.timestamp_end: Optional[float] = None self.timestamp_end: Optional[float] = None
self.request_arrived = threading.Event() self.request_message = self.Message(request_headers)
self.request_data_queue: queue.Queue[bytes] = queue.Queue() self.response_message = self.Message()
self.request_queued_data_length = 0
self.request_data_finished = threading.Event()
self.response_arrived = threading.Event()
self.response_data_queue: queue.Queue[bytes] = queue.Queue()
self.response_queued_data_length = 0
self.response_data_finished = threading.Event()
self.no_body = False
self.priority_exclusive: bool self.priority_exclusive: bool
self.priority_depends_on: Optional[int] = None self.priority_depends_on: Optional[int] = None
@ -427,10 +431,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
def kill(self): def kill(self):
if not self.zombie: if not self.zombie:
self.zombie = time.time() self.zombie = time.time()
self.request_data_finished.set() self.request_message.stream_ended.set()
self.request_arrived.set() self.request_message.arrived.set()
self.response_arrived.set() self.response_message.arrived.set()
self.response_data_finished.set() self.response_message.stream_ended.set()
def connect(self): # pragma: no cover def connect(self): # pragma: no cover
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.") raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
@ -448,28 +452,44 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@property @property
def data_queue(self): def data_queue(self):
if self.response_arrived.is_set(): if self.response_message.arrived.is_set():
return self.response_data_queue return self.response_message.data_queue
else: else:
return self.request_data_queue return self.request_message.data_queue
@property @property
def queued_data_length(self): def queued_data_length(self):
if self.response_arrived.is_set(): if self.response_message.arrived.is_set():
return self.response_queued_data_length return self.response_message.queued_data_length
else: else:
return self.request_queued_data_length return self.request_message.queued_data_length
@queued_data_length.setter @queued_data_length.setter
def queued_data_length(self, v): def queued_data_length(self, v):
self.request_queued_data_length = v self.request_message.queued_data_length = v
@property @property
def data_finished(self): def stream_ended(self):
if self.response_arrived.is_set(): # This indicates that all message headers, the full message body, and all trailers have been received
return self.response_data_finished # https://tools.ietf.org/html/rfc7540#section-8.1
if self.response_message.arrived.is_set():
return self.response_message.stream_ended
else: else:
return self.request_data_finished return self.request_message.stream_ended
@property
def trailers(self):
if self.response_message.arrived.is_set():
return self.response_message.trailers
else:
return self.request_message.trailers
@trailers.setter
def trailers(self, v):
if self.response_message.arrived.is_set():
self.response_message.trailers = v
else:
self.request_message.trailers = v
def raise_zombie(self, pre_command=None): # pragma: no cover def raise_zombie(self, pre_command=None): # pragma: no cover
connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED
@ -480,13 +500,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@detect_zombie_stream @detect_zombie_stream
def read_request_headers(self, flow): def read_request_headers(self, flow):
self.request_arrived.wait() self.request_message.arrived.wait()
self.raise_zombie() self.raise_zombie()
if self.pushed: if self.pushed:
flow.metadata['h2-pushed-stream'] = True flow.metadata['h2-pushed-stream'] = True
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers) first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_message.headers)
return http.HTTPRequest( return http.HTTPRequest(
first_line_format, first_line_format,
method, method,
@ -495,7 +515,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
port, port,
path, path,
b"HTTP/2.0", b"HTTP/2.0",
self.request_headers, self.request_message.headers,
None, None,
timestamp_start=self.timestamp_start, timestamp_start=self.timestamp_start,
timestamp_end=self.timestamp_end, timestamp_end=self.timestamp_end,
@ -504,20 +524,24 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@detect_zombie_stream @detect_zombie_stream
def read_request_body(self, request): def read_request_body(self, request):
if not request.stream: if not request.stream:
self.request_data_finished.wait() self.request_message.stream_ended.wait()
while True: while True:
try: try:
yield self.request_data_queue.get(timeout=0.1) yield self.request_message.data_queue.get(timeout=0.1)
except queue.Empty: # pragma: no cover except queue.Empty: # pragma: no cover
pass pass
if self.request_data_finished.is_set(): if self.request_message.stream_ended.is_set():
self.raise_zombie() self.raise_zombie()
while self.request_data_queue.qsize() > 0: while self.request_message.data_queue.qsize() > 0:
yield self.request_data_queue.get() yield self.request_message.data_queue.get()
break break
self.raise_zombie() self.raise_zombie()
@detect_zombie_stream
def read_request_trailers(self, request):
return self.request_message.trailers
@detect_zombie_stream @detect_zombie_stream
def send_request_headers(self, request): def send_request_headers(self, request):
if self.pushed: if self.pushed:
@ -567,7 +591,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.raise_zombie, self.raise_zombie,
self.server_stream_id, self.server_stream_id,
headers, headers,
end_stream=self.no_body,
priority_exclusive=priority_exclusive, priority_exclusive=priority_exclusive,
priority_depends_on=priority_depends_on, priority_depends_on=priority_depends_on,
priority_weight=priority_weight, priority_weight=priority_weight,
@ -584,26 +607,31 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
# nothing to do here # nothing to do here
return return
if not self.no_body:
self.connections[self.server_conn].safe_send_body( self.connections[self.server_conn].safe_send_body(
self.raise_zombie, self.raise_zombie,
self.server_stream_id, self.server_stream_id,
chunks chunks,
end_stream=(request.trailers is None),
) )
@detect_zombie_stream @detect_zombie_stream
def send_request(self, message): def send_request_trailers(self, request):
self.send_request_headers(message) self._send_trailers(self.server_conn, request.trailers)
self.send_request_body(message, [message.content])
@detect_zombie_stream
def send_request(self, request):
self.send_request_headers(request)
self.send_request_body(request, [request.content])
self.send_request_trailers(request)
@detect_zombie_stream @detect_zombie_stream
def read_response_headers(self): def read_response_headers(self):
self.response_arrived.wait() self.response_message.arrived.wait()
self.raise_zombie() self.raise_zombie()
status_code = int(self.response_headers.get(':status', 502)) status_code = int(self.response_message.headers.get(':status', 502))
headers = self.response_headers.copy() headers = self.response_message.headers.copy()
headers.pop(":status", None) headers.pop(":status", None)
return http.HTTPResponse( return http.HTTPResponse(
@ -620,16 +648,20 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
def read_response_body(self, request, response): def read_response_body(self, request, response):
while True: while True:
try: try:
yield self.response_data_queue.get(timeout=0.1) yield self.response_message.data_queue.get(timeout=0.1)
except queue.Empty: # pragma: no cover except queue.Empty: # pragma: no cover
pass pass
if self.response_data_finished.is_set(): if self.response_message.stream_ended.is_set():
self.raise_zombie() self.raise_zombie()
while self.response_data_queue.qsize() > 0: while self.response_message.data_queue.qsize() > 0:
yield self.response_data_queue.get() yield self.response_message.data_queue.get()
break break
self.raise_zombie() self.raise_zombie()
@detect_zombie_stream
def read_response_trailers(self, request, response):
return self.response_message.trailers
@detect_zombie_stream @detect_zombie_stream
def send_response_headers(self, response): def send_response_headers(self, response):
headers = response.headers.copy() headers = response.headers.copy()
@ -642,11 +674,27 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
) )
@detect_zombie_stream @detect_zombie_stream
def send_response_body(self, _response, chunks): def send_response_body(self, response, chunks):
self.connections[self.client_conn].safe_send_body( self.connections[self.client_conn].safe_send_body(
self.raise_zombie, self.raise_zombie,
self.client_stream_id, self.client_stream_id,
chunks chunks,
end_stream=(response.trailers is None),
)
@detect_zombie_stream
def send_response_trailers(self, response):
self._send_trailers(self.client_conn, response.trailers)
def _send_trailers(self, conn, trailers):
if not trailers:
return
with self.connections[conn].lock:
self.connections[conn].safe_send_headers(
self.raise_zombie,
self.client_stream_id,
trailers,
end_stream=True
) )
def __call__(self): # pragma: no cover def __call__(self): # pragma: no cover

View File

@ -93,6 +93,9 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"timestamp_end": flow.response.timestamp_end, "timestamp_end": flow.response.timestamp_end,
"is_replay": flow.response.is_replay, "is_replay": flow.response.is_replay,
} }
if flow.response.data.trailers:
f["response"]["trailers"] = tuple(flow.response.data.trailers.items(True))
f.get("server_conn", {}).pop("cert", None) f.get("server_conn", {}).pop("cert", None)
f.get("client_conn", {}).pop("mitmcert", None) f.get("client_conn", {}).pop("mitmcert", None)

View File

@ -8,7 +8,7 @@ MITMPROXY = "mitmproxy " + VERSION
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one # Serialization format version. This is displayed nowhere, it just needs to be incremented by one
# for each change in the file format. # for each change in the file format.
FLOW_FORMAT_VERSION = 7 FLOW_FORMAT_VERSION = 8
def get_dev_version() -> str: def get_dev_version() -> str:

View File

@ -110,8 +110,9 @@ class HTTP2StateProtocol:
b"HTTP/2.0", b"HTTP/2.0",
headers, headers,
body, body,
timestamp_start, None,
timestamp_end, timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
) )
request.stream_id = stream_id request.stream_id = stream_id

View File

@ -21,8 +21,11 @@ class TestRequestData:
treq(headers="foobar") treq(headers="foobar")
with pytest.raises(ValueError): with pytest.raises(ValueError):
treq(content="foobar") treq(content="foobar")
with pytest.raises(ValueError):
treq(trailers="foobar")
assert isinstance(treq(headers=()).headers, Headers) assert isinstance(treq(headers=()).headers, Headers)
assert isinstance(treq(trailers=()).trailers, Headers)
class TestRequestCore: class TestRequestCore:

View File

@ -20,8 +20,11 @@ class TestResponseData:
tresp(reason="fööbär") tresp(reason="fööbär")
with pytest.raises(ValueError): with pytest.raises(ValueError):
tresp(content="foobar") tresp(content="foobar")
with pytest.raises(ValueError):
tresp(trailers="foobar")
assert isinstance(tresp(headers=()).headers, Headers) assert isinstance(tresp(headers=()).headers, Headers)
assert isinstance(tresp(trailers=()).trailers, Headers)
class TestResponseCore: class TestResponseCore:

View File

@ -1031,3 +1031,147 @@ class TestResponseStreaming(_Http2Test):
assert data assert data
else: else:
assert data is None assert data is None
class TestRequestTrailers(_Http2Test):
server_trailers_received = False
@classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.RequestReceived):
# reset the value for a fresh test
cls.server_trailers_received = False
elif isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.TrailersReceived):
cls.server_trailers_received = True
elif isinstance(event, h2.events.StreamEnded):
h2_conn.send_headers(event.stream_id, [
(':status', '200'),
('x-my-trailer-request-received', 'success' if cls.server_trailers_received else "failure"),
], end_stream=True)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
@pytest.mark.parametrize('announce', [True, False])
@pytest.mark.parametrize('body', [None, b"foobar"])
def test_trailers(self, announce, body):
h2_conn = self.setup_connection()
stream_id = 1
headers = [
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
]
if announce:
headers.append(('trailer', 'x-my-trailers'))
h2_conn.send_headers(
stream_id=stream_id,
headers=headers,
)
if body:
h2_conn.send_data(stream_id, body)
# send trailers
h2_conn.send_headers(stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
done = False
while not done:
try:
raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamEnded):
done = True
h2_conn.close_connection()
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].request.trailers['x-my-trailers'] == 'foobar'
assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.headers['x-my-trailer-request-received'] == 'success'
class TestResponseTrailers(_Http2Test):
@classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.StreamEnded):
headers = [
(':status', '200'),
]
if event.stream_id == 1:
# special stream_id to activate the Trailer announcement header
headers.append(('trailer', 'x-my-trailers'))
h2_conn.send_headers(event.stream_id, headers)
h2_conn.send_data(event.stream_id, b'response body')
h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
@pytest.mark.parametrize('announce', [True, False])
def test_trailers(self, announce):
response_body_buffer = b''
h2_conn = self.setup_connection()
self._send_request(
self.client.wfile,
h2_conn,
stream_id=(1 if announce else 3),
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
])
trailers_buffer = None
done = False
while not done:
try:
raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.DataReceived):
response_body_buffer += event.data
elif isinstance(event, h2.events.TrailersReceived):
trailers_buffer = event.headers
elif isinstance(event, h2.events.StreamEnded):
done = True
h2_conn.close_connection()
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.content == b'response body'
assert response_body_buffer == b'response body'
assert self.master.state.flows[0].response.trailers['x-my-trailers'] == 'foobar'
assert trailers_buffer == [(b'x-my-trailers', b'foobar')]