Merge pull request #4042 from sanlengjingvv/develop

support HTTP/2 trailers
Thomas Kriechbaumer 2020-07-06 17:14:17 +02:00 committed by GitHub
commit 46a0f69485
16 changed files with 386 additions and 73 deletions

View File

@ -0,0 +1,28 @@
"""
This script simply prints all received HTTP Trailers.
HTTP requests and responses can contain trailing headers, which are sent after
the body is fully transmitted. Such trailers need to be announced by name in the
initial headers, so the receiving endpoint can wait for and read them after the
body.
"""
from mitmproxy import http
from mitmproxy.net.http import Headers
def request(flow: http.HTTPFlow):
if flow.request.trailers:
print("HTTP Trailers detected! Request contains:", flow.request.trailers)
def response(flow: http.HTTPFlow):
if flow.response.trailers:
print("HTTP Trailers detected! Response contains:", flow.response.trailers)
if flow.request.path == "/inject_trailers":
flow.response.headers["trailer"] = "x-my-injected-trailer-header"
flow.response.trailers = Headers([
(b"x-my-injected-trailer-header", b"foobar")
])
print("Injected a new trailer...", flow.response.headers["trailer"])

View File

@ -26,6 +26,7 @@ class HTTPRequest(http.Request):
http_version,
headers,
content,
trailers=None,
timestamp_start=None,
timestamp_end=None,
is_replay=False,
@ -41,6 +42,7 @@ class HTTPRequest(http.Request):
http_version,
headers,
content,
trailers,
timestamp_start,
timestamp_end,
)
@ -73,6 +75,7 @@ class HTTPRequest(http.Request):
http_version=request.data.http_version,
headers=request.data.headers,
content=request.data.content,
trailers=request.data.trailers,
timestamp_start=request.data.timestamp_start,
timestamp_end=request.data.timestamp_end,
)
@ -97,6 +100,7 @@ class HTTPResponse(http.Response):
reason,
headers,
content,
trailers=None,
timestamp_start=None,
timestamp_end=None,
is_replay=False
@ -108,6 +112,7 @@ class HTTPResponse(http.Response):
reason,
headers,
content,
trailers,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
@ -127,6 +132,7 @@ class HTTPResponse(http.Response):
reason=response.data.reason,
headers=response.data.headers,
content=response.data.content,
trailers=response.data.trailers,
timestamp_start=response.data.timestamp_start,
timestamp_end=response.data.timestamp_end,
)

View File

@ -172,6 +172,13 @@ def convert_6_7(data):
return data
def convert_7_8(data):
data["version"] = 8
data["request"]["trailers"] = None
data["response"]["trailers"] = None
return data
def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
@ -226,6 +233,7 @@ converters = {
4: convert_4_5,
5: convert_5_6,
6: convert_6_7,
7: convert_7_8,
}

View File

@ -59,7 +59,7 @@ def read_request_head(rfile):
timestamp_start = rfile.first_byte_timestamp
return request.Request(
form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
form, method, scheme, host, port, path, http_version, headers, None, None, timestamp_start
)
@ -98,7 +98,7 @@ def read_response_head(rfile):
# more accurate timestamp_start
timestamp_start = rfile.first_byte_timestamp
return response.Response(http_version, status_code, message, headers, None, timestamp_start)
return response.Response(http_version, status_code, message, headers, None, None, timestamp_start)
def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):

View File

@ -28,6 +28,8 @@ class MessageData(serializable.Serializable):
def get_state(self):
state = vars(self).copy()
state["headers"] = state["headers"].get_state()
if 'trailers' in state and state["trailers"] is not None:
state["trailers"] = state["trailers"].get_state()
return state
@classmethod
@ -53,6 +55,8 @@ class Message(serializable.Serializable):
@classmethod
def from_state(cls, state):
state["headers"] = mheaders.Headers.from_state(state["headers"])
if 'trailers' in state and state["trailers"] is not None:
state["trailers"] = mheaders.Headers.from_state(state["trailers"])
return cls(**state)
@property
@ -130,6 +134,20 @@ class Message(serializable.Serializable):
content = property(get_content, set_content)
@property
def trailers(self):
"""
Message trailers object
Returns:
mitmproxy.net.http.Headers
"""
return self.data.trailers
@trailers.setter
def trailers(self, h):
self.data.trailers = h
@property
def http_version(self):
"""

View File

@ -29,6 +29,7 @@ class RequestData(message.MessageData):
http_version,
headers=(),
content=None,
trailers=None,
timestamp_start=None,
timestamp_end=None
):
@ -46,6 +47,8 @@ class RequestData(message.MessageData):
headers = nheaders.Headers(headers)
if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
if trailers is not None and not isinstance(trailers, nheaders.Headers):
trailers = nheaders.Headers(trailers)
self.first_line_format = first_line_format
self.method = method
@ -56,6 +59,7 @@ class RequestData(message.MessageData):
self.http_version = http_version
self.headers = headers
self.content = content
self.trailers = trailers
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end

View File

@ -22,6 +22,7 @@ class ResponseData(message.MessageData):
reason=None,
headers=(),
content=None,
trailers=None,
timestamp_start=None,
timestamp_end=None
):
@ -33,12 +34,15 @@ class ResponseData(message.MessageData):
headers = nheaders.Headers(headers)
if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
if trailers is not None and not isinstance(trailers, nheaders.Headers):
trailers = nheaders.Headers(trailers)
self.http_version = http_version
self.status_code = status_code
self.reason = reason
self.headers = headers
self.content = content
self.trailers = trailers
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end

View File

@ -20,6 +20,9 @@ class _HttpTransmissionLayer(base.Layer):
def read_request_body(self, request):
raise NotImplementedError()
def read_request_trailers(self, request):
raise NotImplementedError()
def send_request(self, request):
raise NotImplementedError()
@ -30,11 +33,15 @@ class _HttpTransmissionLayer(base.Layer):
raise NotImplementedError()
yield "this is a generator" # pragma: no cover
def read_response_trailers(self, request, response):
raise NotImplementedError()
def read_response(self, request):
response = self.read_response_headers()
response.data.content = b"".join(
self.read_response_body(request, response)
)
response.data.trailers = self.read_response_trailers(request, response)
return response
def send_response(self, response):
@ -42,6 +49,7 @@ class _HttpTransmissionLayer(base.Layer):
raise exceptions.HttpException("Cannot assemble flow with missing content")
self.send_response_headers(response)
self.send_response_body(response, [response.data.content])
self.send_response_trailers(response)
def send_response_headers(self, response):
raise NotImplementedError()
@ -49,6 +57,9 @@ class _HttpTransmissionLayer(base.Layer):
def send_response_body(self, response, chunks):
raise NotImplementedError()
def send_response_trailers(self, response):
raise NotImplementedError()
def check_close_connection(self, f):
raise NotImplementedError()
@ -255,6 +266,7 @@ class HttpLayer(base.Layer):
f.request.data.content = b"".join(
self.read_request_body(f.request)
)
f.request.data.trailers = self.read_request_trailers(f.request)
f.request.timestamp_end = time.time()
self.channel.ask("http_connect", f)
@ -282,6 +294,9 @@ class HttpLayer(base.Layer):
f.request.data.content = None
else:
f.request.data.content = b"".join(self.read_request_body(request))
f.request.data.trailers = self.read_request_trailers(f.request)
request.timestamp_end = time.time()
except exceptions.HttpException as e:
# We optimistically guess there might be an HTTP client on the
@ -348,6 +363,8 @@ class HttpLayer(base.Layer):
else:
self.send_request_body(f.request, [f.request.data.content])
self.send_request_trailers(f.request)
f.response = self.read_response_headers()
try:
@ -406,6 +423,8 @@ class HttpLayer(base.Layer):
# we now need to emulate the responseheaders hook.
self.channel.ask("responseheaders", f)
f.response.data.trailers = self.read_response_trailers(f.request, f.response)
self.log("response", "debug", [repr(f.response)])
self.channel.ask("response", f)

View File

@ -23,6 +23,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
human.parse_size(self.config.options.body_size_limit)
)
def read_request_trailers(self, request):
if "Trailer" in request.headers:
# TODO: not implemented yet
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
return None
def send_request_headers(self, request):
headers = http1.assemble_request_head(request)
self.server_conn.wfile.write(headers)
@ -33,7 +39,13 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
self.server_conn.wfile.write(chunk)
self.server_conn.wfile.flush()
def send_request_trailers(self, request):
if "Trailer" in request.headers:
# TODO: not implemented yet
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
def send_request(self, request):
# TODO: this does not yet support request trailers
self.server_conn.wfile.write(http1.assemble_request(request))
self.server_conn.wfile.flush()
@ -49,6 +61,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
human.parse_size(self.config.options.body_size_limit)
)
def read_response_trailers(self, request, response):
if "Trailer" in response.headers:
# TODO: not implemented yet
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
return None
def send_response_headers(self, response):
raw = http1.assemble_response_head(response)
self.client_conn.wfile.write(raw)
@ -59,6 +77,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
self.client_conn.wfile.write(chunk)
self.client_conn.wfile.flush()
def send_response_trailers(self, response):
if "Trailer" in response.headers:
# TODO: not implemented yet
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
return
def check_close_connection(self, flow):
request_close = http1.connection_close(
flow.request.http_version,

View File

@ -55,7 +55,7 @@ class SafeH2Connection(connection.H2Connection):
self.send_headers(stream_id, headers.fields, **kwargs)
self.conn.send(self.data_to_send())
def safe_send_body(self, raise_zombie: Callable, stream_id: int, chunks: List[bytes]):
def safe_send_body(self, raise_zombie: Callable, stream_id: int, chunks: List[bytes], end_stream=True):
for chunk in chunks:
position = 0
while position < len(chunk):
@ -75,6 +75,7 @@ class SafeH2Connection(connection.H2Connection):
finally:
self.lock.release()
position += max_outbound_frame_size
if end_stream:
with self.lock:
raise_zombie()
self.end_stream(stream_id)
@ -170,7 +171,7 @@ class Http2Layer(base.Layer):
elif isinstance(event, events.PriorityUpdated):
return self._handle_priority_updated(eid, event)
elif isinstance(event, events.TrailersReceived):
raise NotImplementedError('TrailersReceived not implemented')
return self._handle_trailers(eid, event, is_server, other_conn)
# fail-safe for unhandled events
return True
@ -179,22 +180,21 @@ class Http2Layer(base.Layer):
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers)
self.streams[eid].timestamp_start = time.time()
self.streams[eid].no_body = (event.stream_ended is not None)
if event.priority_updated is not None:
self.streams[eid].priority_exclusive = event.priority_updated.exclusive
self.streams[eid].priority_depends_on = event.priority_updated.depends_on
self.streams[eid].priority_weight = event.priority_updated.weight
self.streams[eid].handled_priority_event = event.priority_updated
self.streams[eid].start()
self.streams[eid].request_arrived.set()
self.streams[eid].request_message.arrived.set()
return True
def _handle_response_received(self, eid, event):
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid].queued_data_length = 0
self.streams[eid].timestamp_start = time.time()
self.streams[eid].response_headers = headers
self.streams[eid].response_arrived.set()
self.streams[eid].response_message.headers = headers
self.streams[eid].response_message.arrived.set()
return True
def _handle_data_received(self, eid, event, source_conn):
@ -219,7 +219,7 @@ class Http2Layer(base.Layer):
def _handle_stream_ended(self, eid):
self.streams[eid].timestamp_end = time.time()
self.streams[eid].data_finished.set()
self.streams[eid].stream_ended.set()
return True
def _handle_stream_reset(self, eid, event, is_server, other_conn):
@ -233,6 +233,11 @@ class Http2Layer(base.Layer):
self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code)
return True
def _handle_trailers(self, eid, event, is_server, other_conn):
trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid].trailers = trailers
return True
def _handle_remote_settings_changed(self, event, other_conn):
new_settings = dict([(key, cs.new_value) for (key, cs) in event.changed_settings.items()])
self.connections[other_conn].safe_update_settings(new_settings)
@ -277,8 +282,8 @@ class Http2Layer(base.Layer):
self.streams[event.pushed_stream_id].pushed = True
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
self.streams[event.pushed_stream_id].timestamp_end = time.time()
self.streams[event.pushed_stream_id].request_arrived.set()
self.streams[event.pushed_stream_id].request_data_finished.set()
self.streams[event.pushed_stream_id].request_message.arrived.set()
self.streams[event.pushed_stream_id].request_message.stream_ended.set()
self.streams[event.pushed_stream_id].start()
return True
@ -392,6 +397,16 @@ def detect_zombie_stream(func): # pragma: no cover
class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThread):
class Message:
def __init__(self, headers=None):
self.headers: Optional[mitmproxy.net.http.Headers] = headers # headers are the first thing to be received on a new stream
self.data_queue: queue.Queue[bytes] = queue.Queue() # contains raw contents of DATA frames
self.queued_data_length = 0 # used to enforce mitmproxy's config.options.body_size_limit
self.trailers: Optional[mitmproxy.net.http.Headers] = None # trailers are received after stream_ended is set
self.arrived = threading.Event() # indicates that the HEADERS (+ CONTINUATION) frames have been received
self.stream_ended = threading.Event() # indicates that a frame with the END_STREAM flag has been received
def __init__(self, ctx, h2_connection, stream_id: int, request_headers: mitmproxy.net.http.Headers) -> None:
super().__init__(
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
@ -400,24 +415,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.zombie: Optional[float] = None
self.client_stream_id: int = stream_id
self.server_stream_id: Optional[int] = None
self.request_headers = request_headers
self.response_headers: Optional[mitmproxy.net.http.Headers] = None
self.pushed = False
self.timestamp_start: Optional[float] = None
self.timestamp_end: Optional[float] = None
self.request_arrived = threading.Event()
self.request_data_queue: queue.Queue[bytes] = queue.Queue()
self.request_queued_data_length = 0
self.request_data_finished = threading.Event()
self.response_arrived = threading.Event()
self.response_data_queue: queue.Queue[bytes] = queue.Queue()
self.response_queued_data_length = 0
self.response_data_finished = threading.Event()
self.no_body = False
self.request_message = self.Message(request_headers)
self.response_message = self.Message()
self.priority_exclusive: bool
self.priority_depends_on: Optional[int] = None
@ -427,10 +431,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
def kill(self):
if not self.zombie:
self.zombie = time.time()
self.request_data_finished.set()
self.request_arrived.set()
self.response_arrived.set()
self.response_data_finished.set()
self.request_message.stream_ended.set()
self.request_message.arrived.set()
self.response_message.arrived.set()
self.response_message.stream_ended.set()
def connect(self): # pragma: no cover
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
@ -448,28 +452,44 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@property
def data_queue(self):
if self.response_arrived.is_set():
return self.response_data_queue
if self.response_message.arrived.is_set():
return self.response_message.data_queue
else:
return self.request_data_queue
return self.request_message.data_queue
@property
def queued_data_length(self):
if self.response_arrived.is_set():
return self.response_queued_data_length
if self.response_message.arrived.is_set():
return self.response_message.queued_data_length
else:
return self.request_queued_data_length
return self.request_message.queued_data_length
@queued_data_length.setter
def queued_data_length(self, v):
self.request_queued_data_length = v
self.request_message.queued_data_length = v
@property
def data_finished(self):
if self.response_arrived.is_set():
return self.response_data_finished
def stream_ended(self):
# This indicates that all message headers, the full message body, and all trailers have been received
# https://tools.ietf.org/html/rfc7540#section-8.1
if self.response_message.arrived.is_set():
return self.response_message.stream_ended
else:
return self.request_data_finished
return self.request_message.stream_ended
@property
def trailers(self):
if self.response_message.arrived.is_set():
return self.response_message.trailers
else:
return self.request_message.trailers
@trailers.setter
def trailers(self, v):
if self.response_message.arrived.is_set():
self.response_message.trailers = v
else:
self.request_message.trailers = v
def raise_zombie(self, pre_command=None): # pragma: no cover
connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED
@ -480,13 +500,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@detect_zombie_stream
def read_request_headers(self, flow):
self.request_arrived.wait()
self.request_message.arrived.wait()
self.raise_zombie()
if self.pushed:
flow.metadata['h2-pushed-stream'] = True
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers)
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_message.headers)
return http.HTTPRequest(
first_line_format,
method,
@ -495,7 +515,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
port,
path,
b"HTTP/2.0",
self.request_headers,
self.request_message.headers,
None,
timestamp_start=self.timestamp_start,
timestamp_end=self.timestamp_end,
@ -504,20 +524,24 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@detect_zombie_stream
def read_request_body(self, request):
if not request.stream:
self.request_data_finished.wait()
self.request_message.stream_ended.wait()
while True:
try:
yield self.request_data_queue.get(timeout=0.1)
yield self.request_message.data_queue.get(timeout=0.1)
except queue.Empty: # pragma: no cover
pass
if self.request_data_finished.is_set():
if self.request_message.stream_ended.is_set():
self.raise_zombie()
while self.request_data_queue.qsize() > 0:
yield self.request_data_queue.get()
while self.request_message.data_queue.qsize() > 0:
yield self.request_message.data_queue.get()
break
self.raise_zombie()
@detect_zombie_stream
def read_request_trailers(self, request):
return self.request_message.trailers
@detect_zombie_stream
def send_request_headers(self, request):
if self.pushed:
@ -567,7 +591,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.raise_zombie,
self.server_stream_id,
headers,
end_stream=self.no_body,
priority_exclusive=priority_exclusive,
priority_depends_on=priority_depends_on,
priority_weight=priority_weight,
@ -584,26 +607,31 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
# nothing to do here
return
if not self.no_body:
self.connections[self.server_conn].safe_send_body(
self.raise_zombie,
self.server_stream_id,
chunks
chunks,
end_stream=(request.trailers is None),
)
@detect_zombie_stream
def send_request(self, message):
self.send_request_headers(message)
self.send_request_body(message, [message.content])
def send_request_trailers(self, request):
self._send_trailers(self.server_conn, request.trailers)
@detect_zombie_stream
def send_request(self, request):
self.send_request_headers(request)
self.send_request_body(request, [request.content])
self.send_request_trailers(request)
@detect_zombie_stream
def read_response_headers(self):
self.response_arrived.wait()
self.response_message.arrived.wait()
self.raise_zombie()
status_code = int(self.response_headers.get(':status', 502))
headers = self.response_headers.copy()
status_code = int(self.response_message.headers.get(':status', 502))
headers = self.response_message.headers.copy()
headers.pop(":status", None)
return http.HTTPResponse(
@ -620,16 +648,20 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
def read_response_body(self, request, response):
while True:
try:
yield self.response_data_queue.get(timeout=0.1)
yield self.response_message.data_queue.get(timeout=0.1)
except queue.Empty: # pragma: no cover
pass
if self.response_data_finished.is_set():
if self.response_message.stream_ended.is_set():
self.raise_zombie()
while self.response_data_queue.qsize() > 0:
yield self.response_data_queue.get()
while self.response_message.data_queue.qsize() > 0:
yield self.response_message.data_queue.get()
break
self.raise_zombie()
@detect_zombie_stream
def read_response_trailers(self, request, response):
return self.response_message.trailers
@detect_zombie_stream
def send_response_headers(self, response):
headers = response.headers.copy()
@ -642,11 +674,27 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
)
@detect_zombie_stream
def send_response_body(self, _response, chunks):
def send_response_body(self, response, chunks):
self.connections[self.client_conn].safe_send_body(
self.raise_zombie,
self.client_stream_id,
chunks
chunks,
end_stream=(response.trailers is None),
)
@detect_zombie_stream
def send_response_trailers(self, response):
self._send_trailers(self.client_conn, response.trailers)
def _send_trailers(self, conn, trailers):
if not trailers:
return
with self.connections[conn].lock:
self.connections[conn].safe_send_headers(
self.raise_zombie,
self.client_stream_id,
trailers,
end_stream=True
)
def __call__(self): # pragma: no cover

View File

@ -93,6 +93,9 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
"timestamp_end": flow.response.timestamp_end,
"is_replay": flow.response.is_replay,
}
if flow.response.data.trailers:
f["response"]["trailers"] = tuple(flow.response.data.trailers.items(True))
f.get("server_conn", {}).pop("cert", None)
f.get("client_conn", {}).pop("mitmcert", None)

View File

@ -8,7 +8,7 @@ MITMPROXY = "mitmproxy " + VERSION
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
# for each change in the file format.
FLOW_FORMAT_VERSION = 7
FLOW_FORMAT_VERSION = 8
def get_dev_version() -> str:
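For context (a hedged sketch, not part of this change): the converters table added earlier in this diff is keyed by the old format version, so a flow recorded at version 7 passes through convert_7_8 exactly once on load and comes out at version 8 with trailers set to None. The helper below is illustrative, not mitmproxy's actual loader.

def migrate(data: dict, converters: dict) -> dict:
    # repeatedly apply the converter registered for the flow's current version
    # until it reaches FLOW_FORMAT_VERSION (8 after this change)
    while data["version"] < 8:
        data = converters[data["version"]](data)  # e.g. 7 -> convert_7_8 adds trailers=None
    return data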

View File

@ -110,8 +110,9 @@ class HTTP2StateProtocol:
b"HTTP/2.0",
headers,
body,
timestamp_start,
timestamp_end,
None,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
request.stream_id = stream_id

View File

@ -21,8 +21,11 @@ class TestRequestData:
treq(headers="foobar")
with pytest.raises(ValueError):
treq(content="foobar")
with pytest.raises(ValueError):
treq(trailers="foobar")
assert isinstance(treq(headers=()).headers, Headers)
assert isinstance(treq(trailers=()).trailers, Headers)
class TestRequestCore:

View File

@ -20,8 +20,11 @@ class TestResponseData:
tresp(reason="fööbär")
with pytest.raises(ValueError):
tresp(content="foobar")
with pytest.raises(ValueError):
tresp(trailers="foobar")
assert isinstance(tresp(headers=()).headers, Headers)
assert isinstance(tresp(trailers=()).trailers, Headers)
class TestResponseCore:

View File

@ -1031,3 +1031,147 @@ class TestResponseStreaming(_Http2Test):
assert data
else:
assert data is None
class TestRequestTrailers(_Http2Test):
server_trailers_received = False
@classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.RequestReceived):
# reset the value for a fresh test
cls.server_trailers_received = False
elif isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.TrailersReceived):
cls.server_trailers_received = True
elif isinstance(event, h2.events.StreamEnded):
h2_conn.send_headers(event.stream_id, [
(':status', '200'),
('x-my-trailer-request-received', 'success' if cls.server_trailers_received else "failure"),
], end_stream=True)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
@pytest.mark.parametrize('announce', [True, False])
@pytest.mark.parametrize('body', [None, b"foobar"])
def test_trailers(self, announce, body):
h2_conn = self.setup_connection()
stream_id = 1
headers = [
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
]
if announce:
headers.append(('trailer', 'x-my-trailers'))
h2_conn.send_headers(
stream_id=stream_id,
headers=headers,
)
if body:
h2_conn.send_data(stream_id, body)
# send trailers
h2_conn.send_headers(stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
done = False
while not done:
try:
raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamEnded):
done = True
h2_conn.close_connection()
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].request.trailers['x-my-trailers'] == 'foobar'
assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.headers['x-my-trailer-request-received'] == 'success'
class TestResponseTrailers(_Http2Test):
@classmethod
def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.StreamEnded):
headers = [
(':status', '200'),
]
if event.stream_id == 1:
# special stream_id to activate the Trailer announcement header
headers.append(('trailer', 'x-my-trailers'))
h2_conn.send_headers(event.stream_id, headers)
h2_conn.send_data(event.stream_id, b'response body')
h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
wfile.write(h2_conn.data_to_send())
wfile.flush()
return True
@pytest.mark.parametrize('announce', [True, False])
def test_trailers(self, announce):
response_body_buffer = b''
h2_conn = self.setup_connection()
self._send_request(
self.client.wfile,
h2_conn,
stream_id=(1 if announce else 3),
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
])
trailers_buffer = None
done = False
while not done:
try:
raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.DataReceived):
response_body_buffer += event.data
elif isinstance(event, h2.events.TrailersReceived):
trailers_buffer = event.headers
elif isinstance(event, h2.events.StreamEnded):
done = True
h2_conn.close_connection()
self.client.wfile.write(h2_conn.data_to_send())
self.client.wfile.flush()
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.content == b'response body'
assert response_body_buffer == b'response body'
assert self.master.state.flows[0].response.trailers['x-my-trailers'] == 'foobar'
assert trailers_buffer == [(b'x-my-trailers', b'foobar')]