mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-29 19:08:44 +00:00
Merge pull request #4042 from sanlengjingvv/develop
support HTTP/2 trailers
This commit is contained in:
commit
46a0f69485
28
examples/addons/http-trailers.py
Normal file
28
examples/addons/http-trailers.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
This script simply prints all received HTTP Trailers.
|
||||||
|
|
||||||
|
HTTP requests and responses can container trailing headers which are sent after
|
||||||
|
the body is fully transmitted. Such trailers need to be announced in the initial
|
||||||
|
headers by name, so the receiving endpoint can wait and read them after the
|
||||||
|
body.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from mitmproxy import http
|
||||||
|
from mitmproxy.net.http import Headers
|
||||||
|
|
||||||
|
|
||||||
|
def request(flow: http.HTTPFlow):
|
||||||
|
if flow.request.trailers:
|
||||||
|
print("HTTP Trailers detected! Request contains:", flow.request.trailers)
|
||||||
|
|
||||||
|
|
||||||
|
def response(flow: http.HTTPFlow):
|
||||||
|
if flow.response.trailers:
|
||||||
|
print("HTTP Trailers detected! Response contains:", flow.response.trailers)
|
||||||
|
|
||||||
|
if flow.request.path == "/inject_trailers":
|
||||||
|
flow.response.headers["trailer"] = "x-my-injected-trailer-header"
|
||||||
|
flow.response.trailers = Headers([
|
||||||
|
(b"x-my-injected-trailer-header", b"foobar")
|
||||||
|
])
|
||||||
|
print("Injected a new trailer...", flow.response.headers["trailer"])
|
@ -26,6 +26,7 @@ class HTTPRequest(http.Request):
|
|||||||
http_version,
|
http_version,
|
||||||
headers,
|
headers,
|
||||||
content,
|
content,
|
||||||
|
trailers=None,
|
||||||
timestamp_start=None,
|
timestamp_start=None,
|
||||||
timestamp_end=None,
|
timestamp_end=None,
|
||||||
is_replay=False,
|
is_replay=False,
|
||||||
@ -41,6 +42,7 @@ class HTTPRequest(http.Request):
|
|||||||
http_version,
|
http_version,
|
||||||
headers,
|
headers,
|
||||||
content,
|
content,
|
||||||
|
trailers,
|
||||||
timestamp_start,
|
timestamp_start,
|
||||||
timestamp_end,
|
timestamp_end,
|
||||||
)
|
)
|
||||||
@ -73,6 +75,7 @@ class HTTPRequest(http.Request):
|
|||||||
http_version=request.data.http_version,
|
http_version=request.data.http_version,
|
||||||
headers=request.data.headers,
|
headers=request.data.headers,
|
||||||
content=request.data.content,
|
content=request.data.content,
|
||||||
|
trailers=request.data.trailers,
|
||||||
timestamp_start=request.data.timestamp_start,
|
timestamp_start=request.data.timestamp_start,
|
||||||
timestamp_end=request.data.timestamp_end,
|
timestamp_end=request.data.timestamp_end,
|
||||||
)
|
)
|
||||||
@ -97,6 +100,7 @@ class HTTPResponse(http.Response):
|
|||||||
reason,
|
reason,
|
||||||
headers,
|
headers,
|
||||||
content,
|
content,
|
||||||
|
trailers=None,
|
||||||
timestamp_start=None,
|
timestamp_start=None,
|
||||||
timestamp_end=None,
|
timestamp_end=None,
|
||||||
is_replay=False
|
is_replay=False
|
||||||
@ -108,6 +112,7 @@ class HTTPResponse(http.Response):
|
|||||||
reason,
|
reason,
|
||||||
headers,
|
headers,
|
||||||
content,
|
content,
|
||||||
|
trailers,
|
||||||
timestamp_start=timestamp_start,
|
timestamp_start=timestamp_start,
|
||||||
timestamp_end=timestamp_end,
|
timestamp_end=timestamp_end,
|
||||||
)
|
)
|
||||||
@ -127,6 +132,7 @@ class HTTPResponse(http.Response):
|
|||||||
reason=response.data.reason,
|
reason=response.data.reason,
|
||||||
headers=response.data.headers,
|
headers=response.data.headers,
|
||||||
content=response.data.content,
|
content=response.data.content,
|
||||||
|
trailers=response.data.trailers,
|
||||||
timestamp_start=response.data.timestamp_start,
|
timestamp_start=response.data.timestamp_start,
|
||||||
timestamp_end=response.data.timestamp_end,
|
timestamp_end=response.data.timestamp_end,
|
||||||
)
|
)
|
||||||
|
@ -172,6 +172,13 @@ def convert_6_7(data):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def convert_7_8(data):
|
||||||
|
data["version"] = 8
|
||||||
|
data["request"]["trailers"] = None
|
||||||
|
data["response"]["trailers"] = None
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _convert_dict_keys(o: Any) -> Any:
|
def _convert_dict_keys(o: Any) -> Any:
|
||||||
if isinstance(o, dict):
|
if isinstance(o, dict):
|
||||||
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
|
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
|
||||||
@ -226,6 +233,7 @@ converters = {
|
|||||||
4: convert_4_5,
|
4: convert_4_5,
|
||||||
5: convert_5_6,
|
5: convert_5_6,
|
||||||
6: convert_6_7,
|
6: convert_6_7,
|
||||||
|
7: convert_7_8,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,7 +59,7 @@ def read_request_head(rfile):
|
|||||||
timestamp_start = rfile.first_byte_timestamp
|
timestamp_start = rfile.first_byte_timestamp
|
||||||
|
|
||||||
return request.Request(
|
return request.Request(
|
||||||
form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
|
form, method, scheme, host, port, path, http_version, headers, None, None, timestamp_start
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ def read_response_head(rfile):
|
|||||||
# more accurate timestamp_start
|
# more accurate timestamp_start
|
||||||
timestamp_start = rfile.first_byte_timestamp
|
timestamp_start = rfile.first_byte_timestamp
|
||||||
|
|
||||||
return response.Response(http_version, status_code, message, headers, None, timestamp_start)
|
return response.Response(http_version, status_code, message, headers, None, None, timestamp_start)
|
||||||
|
|
||||||
|
|
||||||
def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
|
def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
|
||||||
|
@ -28,6 +28,8 @@ class MessageData(serializable.Serializable):
|
|||||||
def get_state(self):
|
def get_state(self):
|
||||||
state = vars(self).copy()
|
state = vars(self).copy()
|
||||||
state["headers"] = state["headers"].get_state()
|
state["headers"] = state["headers"].get_state()
|
||||||
|
if 'trailers' in state and state["trailers"] is not None:
|
||||||
|
state["trailers"] = state["trailers"].get_state()
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -53,6 +55,8 @@ class Message(serializable.Serializable):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_state(cls, state):
|
def from_state(cls, state):
|
||||||
state["headers"] = mheaders.Headers.from_state(state["headers"])
|
state["headers"] = mheaders.Headers.from_state(state["headers"])
|
||||||
|
if 'trailers' in state and state["trailers"] is not None:
|
||||||
|
state["trailers"] = mheaders.Headers.from_state(state["trailers"])
|
||||||
return cls(**state)
|
return cls(**state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -130,6 +134,20 @@ class Message(serializable.Serializable):
|
|||||||
|
|
||||||
content = property(get_content, set_content)
|
content = property(get_content, set_content)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trailers(self):
|
||||||
|
"""
|
||||||
|
Message trailers object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mitmproxy.net.http.Headers
|
||||||
|
"""
|
||||||
|
return self.data.trailers
|
||||||
|
|
||||||
|
@trailers.setter
|
||||||
|
def trailers(self, h):
|
||||||
|
self.data.trailers = h
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def http_version(self):
|
def http_version(self):
|
||||||
"""
|
"""
|
||||||
|
@ -29,6 +29,7 @@ class RequestData(message.MessageData):
|
|||||||
http_version,
|
http_version,
|
||||||
headers=(),
|
headers=(),
|
||||||
content=None,
|
content=None,
|
||||||
|
trailers=None,
|
||||||
timestamp_start=None,
|
timestamp_start=None,
|
||||||
timestamp_end=None
|
timestamp_end=None
|
||||||
):
|
):
|
||||||
@ -46,6 +47,8 @@ class RequestData(message.MessageData):
|
|||||||
headers = nheaders.Headers(headers)
|
headers = nheaders.Headers(headers)
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
||||||
|
if trailers is not None and not isinstance(trailers, nheaders.Headers):
|
||||||
|
trailers = nheaders.Headers(trailers)
|
||||||
|
|
||||||
self.first_line_format = first_line_format
|
self.first_line_format = first_line_format
|
||||||
self.method = method
|
self.method = method
|
||||||
@ -56,6 +59,7 @@ class RequestData(message.MessageData):
|
|||||||
self.http_version = http_version
|
self.http_version = http_version
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
self.content = content
|
self.content = content
|
||||||
|
self.trailers = trailers
|
||||||
self.timestamp_start = timestamp_start
|
self.timestamp_start = timestamp_start
|
||||||
self.timestamp_end = timestamp_end
|
self.timestamp_end = timestamp_end
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ class ResponseData(message.MessageData):
|
|||||||
reason=None,
|
reason=None,
|
||||||
headers=(),
|
headers=(),
|
||||||
content=None,
|
content=None,
|
||||||
|
trailers=None,
|
||||||
timestamp_start=None,
|
timestamp_start=None,
|
||||||
timestamp_end=None
|
timestamp_end=None
|
||||||
):
|
):
|
||||||
@ -33,12 +34,15 @@ class ResponseData(message.MessageData):
|
|||||||
headers = nheaders.Headers(headers)
|
headers = nheaders.Headers(headers)
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
||||||
|
if trailers is not None and not isinstance(trailers, nheaders.Headers):
|
||||||
|
trailers = nheaders.Headers(trailers)
|
||||||
|
|
||||||
self.http_version = http_version
|
self.http_version = http_version
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
self.reason = reason
|
self.reason = reason
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
self.content = content
|
self.content = content
|
||||||
|
self.trailers = trailers
|
||||||
self.timestamp_start = timestamp_start
|
self.timestamp_start = timestamp_start
|
||||||
self.timestamp_end = timestamp_end
|
self.timestamp_end = timestamp_end
|
||||||
|
|
||||||
|
@ -20,6 +20,9 @@ class _HttpTransmissionLayer(base.Layer):
|
|||||||
def read_request_body(self, request):
|
def read_request_body(self, request):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def read_request_trailers(self, request):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def send_request(self, request):
|
def send_request(self, request):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -30,11 +33,15 @@ class _HttpTransmissionLayer(base.Layer):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
yield "this is a generator" # pragma: no cover
|
yield "this is a generator" # pragma: no cover
|
||||||
|
|
||||||
|
def read_response_trailers(self, request, response):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def read_response(self, request):
|
def read_response(self, request):
|
||||||
response = self.read_response_headers()
|
response = self.read_response_headers()
|
||||||
response.data.content = b"".join(
|
response.data.content = b"".join(
|
||||||
self.read_response_body(request, response)
|
self.read_response_body(request, response)
|
||||||
)
|
)
|
||||||
|
response.data.trailers = self.read_response_trailers(request, response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def send_response(self, response):
|
def send_response(self, response):
|
||||||
@ -42,6 +49,7 @@ class _HttpTransmissionLayer(base.Layer):
|
|||||||
raise exceptions.HttpException("Cannot assemble flow with missing content")
|
raise exceptions.HttpException("Cannot assemble flow with missing content")
|
||||||
self.send_response_headers(response)
|
self.send_response_headers(response)
|
||||||
self.send_response_body(response, [response.data.content])
|
self.send_response_body(response, [response.data.content])
|
||||||
|
self.send_response_trailers(response)
|
||||||
|
|
||||||
def send_response_headers(self, response):
|
def send_response_headers(self, response):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -49,6 +57,9 @@ class _HttpTransmissionLayer(base.Layer):
|
|||||||
def send_response_body(self, response, chunks):
|
def send_response_body(self, response, chunks):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def send_response_trailers(self, response, chunks):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def check_close_connection(self, f):
|
def check_close_connection(self, f):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -255,6 +266,7 @@ class HttpLayer(base.Layer):
|
|||||||
f.request.data.content = b"".join(
|
f.request.data.content = b"".join(
|
||||||
self.read_request_body(f.request)
|
self.read_request_body(f.request)
|
||||||
)
|
)
|
||||||
|
f.request.data.trailers = self.read_request_trailers(f.request)
|
||||||
f.request.timestamp_end = time.time()
|
f.request.timestamp_end = time.time()
|
||||||
self.channel.ask("http_connect", f)
|
self.channel.ask("http_connect", f)
|
||||||
|
|
||||||
@ -282,6 +294,9 @@ class HttpLayer(base.Layer):
|
|||||||
f.request.data.content = None
|
f.request.data.content = None
|
||||||
else:
|
else:
|
||||||
f.request.data.content = b"".join(self.read_request_body(request))
|
f.request.data.content = b"".join(self.read_request_body(request))
|
||||||
|
|
||||||
|
f.request.data.trailers = self.read_request_trailers(f.request)
|
||||||
|
|
||||||
request.timestamp_end = time.time()
|
request.timestamp_end = time.time()
|
||||||
except exceptions.HttpException as e:
|
except exceptions.HttpException as e:
|
||||||
# We optimistically guess there might be an HTTP client on the
|
# We optimistically guess there might be an HTTP client on the
|
||||||
@ -348,6 +363,8 @@ class HttpLayer(base.Layer):
|
|||||||
else:
|
else:
|
||||||
self.send_request_body(f.request, [f.request.data.content])
|
self.send_request_body(f.request, [f.request.data.content])
|
||||||
|
|
||||||
|
self.send_request_trailers(f.request)
|
||||||
|
|
||||||
f.response = self.read_response_headers()
|
f.response = self.read_response_headers()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -406,6 +423,8 @@ class HttpLayer(base.Layer):
|
|||||||
# we now need to emulate the responseheaders hook.
|
# we now need to emulate the responseheaders hook.
|
||||||
self.channel.ask("responseheaders", f)
|
self.channel.ask("responseheaders", f)
|
||||||
|
|
||||||
|
f.response.data.trailers = self.read_response_trailers(f.request, f.response)
|
||||||
|
|
||||||
self.log("response", "debug", [repr(f.response)])
|
self.log("response", "debug", [repr(f.response)])
|
||||||
self.channel.ask("response", f)
|
self.channel.ask("response", f)
|
||||||
|
|
||||||
|
@ -23,6 +23,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
|||||||
human.parse_size(self.config.options.body_size_limit)
|
human.parse_size(self.config.options.body_size_limit)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def read_request_trailers(self, request):
|
||||||
|
if "Trailer" in request.headers:
|
||||||
|
# TODO: not implemented yet
|
||||||
|
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
|
||||||
|
return None
|
||||||
|
|
||||||
def send_request_headers(self, request):
|
def send_request_headers(self, request):
|
||||||
headers = http1.assemble_request_head(request)
|
headers = http1.assemble_request_head(request)
|
||||||
self.server_conn.wfile.write(headers)
|
self.server_conn.wfile.write(headers)
|
||||||
@ -33,7 +39,13 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
|||||||
self.server_conn.wfile.write(chunk)
|
self.server_conn.wfile.write(chunk)
|
||||||
self.server_conn.wfile.flush()
|
self.server_conn.wfile.flush()
|
||||||
|
|
||||||
|
def send_request_trailers(self, request):
|
||||||
|
if "Trailer" in request.headers:
|
||||||
|
# TODO: not implemented yet
|
||||||
|
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
|
||||||
|
|
||||||
def send_request(self, request):
|
def send_request(self, request):
|
||||||
|
# TODO: this does not yet support request trailers
|
||||||
self.server_conn.wfile.write(http1.assemble_request(request))
|
self.server_conn.wfile.write(http1.assemble_request(request))
|
||||||
self.server_conn.wfile.flush()
|
self.server_conn.wfile.flush()
|
||||||
|
|
||||||
@ -49,6 +61,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
|||||||
human.parse_size(self.config.options.body_size_limit)
|
human.parse_size(self.config.options.body_size_limit)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def read_response_trailers(self, request, response):
|
||||||
|
if "Trailer" in response.headers:
|
||||||
|
# TODO: not implemented yet
|
||||||
|
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
|
||||||
|
return None
|
||||||
|
|
||||||
def send_response_headers(self, response):
|
def send_response_headers(self, response):
|
||||||
raw = http1.assemble_response_head(response)
|
raw = http1.assemble_response_head(response)
|
||||||
self.client_conn.wfile.write(raw)
|
self.client_conn.wfile.write(raw)
|
||||||
@ -59,6 +77,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
|||||||
self.client_conn.wfile.write(chunk)
|
self.client_conn.wfile.write(chunk)
|
||||||
self.client_conn.wfile.flush()
|
self.client_conn.wfile.flush()
|
||||||
|
|
||||||
|
def send_response_trailers(self, response):
|
||||||
|
if "Trailer" in response.headers:
|
||||||
|
# TODO: not implemented yet
|
||||||
|
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
|
||||||
|
return
|
||||||
|
|
||||||
def check_close_connection(self, flow):
|
def check_close_connection(self, flow):
|
||||||
request_close = http1.connection_close(
|
request_close = http1.connection_close(
|
||||||
flow.request.http_version,
|
flow.request.http_version,
|
||||||
|
@ -55,7 +55,7 @@ class SafeH2Connection(connection.H2Connection):
|
|||||||
self.send_headers(stream_id, headers.fields, **kwargs)
|
self.send_headers(stream_id, headers.fields, **kwargs)
|
||||||
self.conn.send(self.data_to_send())
|
self.conn.send(self.data_to_send())
|
||||||
|
|
||||||
def safe_send_body(self, raise_zombie: Callable, stream_id: int, chunks: List[bytes]):
|
def safe_send_body(self, raise_zombie: Callable, stream_id: int, chunks: List[bytes], end_stream=True):
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
position = 0
|
position = 0
|
||||||
while position < len(chunk):
|
while position < len(chunk):
|
||||||
@ -75,6 +75,7 @@ class SafeH2Connection(connection.H2Connection):
|
|||||||
finally:
|
finally:
|
||||||
self.lock.release()
|
self.lock.release()
|
||||||
position += max_outbound_frame_size
|
position += max_outbound_frame_size
|
||||||
|
if end_stream:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
raise_zombie()
|
raise_zombie()
|
||||||
self.end_stream(stream_id)
|
self.end_stream(stream_id)
|
||||||
@ -170,7 +171,7 @@ class Http2Layer(base.Layer):
|
|||||||
elif isinstance(event, events.PriorityUpdated):
|
elif isinstance(event, events.PriorityUpdated):
|
||||||
return self._handle_priority_updated(eid, event)
|
return self._handle_priority_updated(eid, event)
|
||||||
elif isinstance(event, events.TrailersReceived):
|
elif isinstance(event, events.TrailersReceived):
|
||||||
raise NotImplementedError('TrailersReceived not implemented')
|
return self._handle_trailers(eid, event, is_server, other_conn)
|
||||||
|
|
||||||
# fail-safe for unhandled events
|
# fail-safe for unhandled events
|
||||||
return True
|
return True
|
||||||
@ -179,22 +180,21 @@ class Http2Layer(base.Layer):
|
|||||||
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
||||||
self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers)
|
self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers)
|
||||||
self.streams[eid].timestamp_start = time.time()
|
self.streams[eid].timestamp_start = time.time()
|
||||||
self.streams[eid].no_body = (event.stream_ended is not None)
|
|
||||||
if event.priority_updated is not None:
|
if event.priority_updated is not None:
|
||||||
self.streams[eid].priority_exclusive = event.priority_updated.exclusive
|
self.streams[eid].priority_exclusive = event.priority_updated.exclusive
|
||||||
self.streams[eid].priority_depends_on = event.priority_updated.depends_on
|
self.streams[eid].priority_depends_on = event.priority_updated.depends_on
|
||||||
self.streams[eid].priority_weight = event.priority_updated.weight
|
self.streams[eid].priority_weight = event.priority_updated.weight
|
||||||
self.streams[eid].handled_priority_event = event.priority_updated
|
self.streams[eid].handled_priority_event = event.priority_updated
|
||||||
self.streams[eid].start()
|
self.streams[eid].start()
|
||||||
self.streams[eid].request_arrived.set()
|
self.streams[eid].request_message.arrived.set()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_response_received(self, eid, event):
|
def _handle_response_received(self, eid, event):
|
||||||
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
||||||
self.streams[eid].queued_data_length = 0
|
self.streams[eid].queued_data_length = 0
|
||||||
self.streams[eid].timestamp_start = time.time()
|
self.streams[eid].timestamp_start = time.time()
|
||||||
self.streams[eid].response_headers = headers
|
self.streams[eid].response_message.headers = headers
|
||||||
self.streams[eid].response_arrived.set()
|
self.streams[eid].response_message.arrived.set()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_data_received(self, eid, event, source_conn):
|
def _handle_data_received(self, eid, event, source_conn):
|
||||||
@ -219,7 +219,7 @@ class Http2Layer(base.Layer):
|
|||||||
|
|
||||||
def _handle_stream_ended(self, eid):
|
def _handle_stream_ended(self, eid):
|
||||||
self.streams[eid].timestamp_end = time.time()
|
self.streams[eid].timestamp_end = time.time()
|
||||||
self.streams[eid].data_finished.set()
|
self.streams[eid].stream_ended.set()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _handle_stream_reset(self, eid, event, is_server, other_conn):
|
def _handle_stream_reset(self, eid, event, is_server, other_conn):
|
||||||
@ -233,6 +233,11 @@ class Http2Layer(base.Layer):
|
|||||||
self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code)
|
self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _handle_trailers(self, eid, event, is_server, other_conn):
|
||||||
|
trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
||||||
|
self.streams[eid].trailers = trailers
|
||||||
|
return True
|
||||||
|
|
||||||
def _handle_remote_settings_changed(self, event, other_conn):
|
def _handle_remote_settings_changed(self, event, other_conn):
|
||||||
new_settings = dict([(key, cs.new_value) for (key, cs) in event.changed_settings.items()])
|
new_settings = dict([(key, cs.new_value) for (key, cs) in event.changed_settings.items()])
|
||||||
self.connections[other_conn].safe_update_settings(new_settings)
|
self.connections[other_conn].safe_update_settings(new_settings)
|
||||||
@ -277,8 +282,8 @@ class Http2Layer(base.Layer):
|
|||||||
self.streams[event.pushed_stream_id].pushed = True
|
self.streams[event.pushed_stream_id].pushed = True
|
||||||
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
|
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
|
||||||
self.streams[event.pushed_stream_id].timestamp_end = time.time()
|
self.streams[event.pushed_stream_id].timestamp_end = time.time()
|
||||||
self.streams[event.pushed_stream_id].request_arrived.set()
|
self.streams[event.pushed_stream_id].request_message.arrived.set()
|
||||||
self.streams[event.pushed_stream_id].request_data_finished.set()
|
self.streams[event.pushed_stream_id].request_message.stream_ended.set()
|
||||||
self.streams[event.pushed_stream_id].start()
|
self.streams[event.pushed_stream_id].start()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -392,6 +397,16 @@ def detect_zombie_stream(func): # pragma: no cover
|
|||||||
|
|
||||||
class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThread):
|
class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThread):
|
||||||
|
|
||||||
|
class Message:
|
||||||
|
def __init__(self, headers=None):
|
||||||
|
self.headers: Optional[mitmproxy.net.http.Headers] = headers # headers are the first thing to be received on a new stream
|
||||||
|
self.data_queue: queue.Queue[bytes] = queue.Queue() # contains raw contents of DATA frames
|
||||||
|
self.queued_data_length = 0 # used to enforce mitmproxy's config.options.body_size_limit
|
||||||
|
self.trailers: Optional[mitmproxy.net.http.Headers] = None # trailers are received after stream_ended is set
|
||||||
|
|
||||||
|
self.arrived = threading.Event() # indicates the HEADERS+CONTINUTATION frames have been received
|
||||||
|
self.stream_ended = threading.Event() # indicates the a frame with the END_STREAM flag has been received
|
||||||
|
|
||||||
def __init__(self, ctx, h2_connection, stream_id: int, request_headers: mitmproxy.net.http.Headers) -> None:
|
def __init__(self, ctx, h2_connection, stream_id: int, request_headers: mitmproxy.net.http.Headers) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
|
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
|
||||||
@ -400,24 +415,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
self.zombie: Optional[float] = None
|
self.zombie: Optional[float] = None
|
||||||
self.client_stream_id: int = stream_id
|
self.client_stream_id: int = stream_id
|
||||||
self.server_stream_id: Optional[int] = None
|
self.server_stream_id: Optional[int] = None
|
||||||
self.request_headers = request_headers
|
|
||||||
self.response_headers: Optional[mitmproxy.net.http.Headers] = None
|
|
||||||
self.pushed = False
|
self.pushed = False
|
||||||
|
|
||||||
self.timestamp_start: Optional[float] = None
|
self.timestamp_start: Optional[float] = None
|
||||||
self.timestamp_end: Optional[float] = None
|
self.timestamp_end: Optional[float] = None
|
||||||
|
|
||||||
self.request_arrived = threading.Event()
|
self.request_message = self.Message(request_headers)
|
||||||
self.request_data_queue: queue.Queue[bytes] = queue.Queue()
|
self.response_message = self.Message()
|
||||||
self.request_queued_data_length = 0
|
|
||||||
self.request_data_finished = threading.Event()
|
|
||||||
|
|
||||||
self.response_arrived = threading.Event()
|
|
||||||
self.response_data_queue: queue.Queue[bytes] = queue.Queue()
|
|
||||||
self.response_queued_data_length = 0
|
|
||||||
self.response_data_finished = threading.Event()
|
|
||||||
|
|
||||||
self.no_body = False
|
|
||||||
|
|
||||||
self.priority_exclusive: bool
|
self.priority_exclusive: bool
|
||||||
self.priority_depends_on: Optional[int] = None
|
self.priority_depends_on: Optional[int] = None
|
||||||
@ -427,10 +431,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
def kill(self):
|
def kill(self):
|
||||||
if not self.zombie:
|
if not self.zombie:
|
||||||
self.zombie = time.time()
|
self.zombie = time.time()
|
||||||
self.request_data_finished.set()
|
self.request_message.stream_ended.set()
|
||||||
self.request_arrived.set()
|
self.request_message.arrived.set()
|
||||||
self.response_arrived.set()
|
self.response_message.arrived.set()
|
||||||
self.response_data_finished.set()
|
self.response_message.stream_ended.set()
|
||||||
|
|
||||||
def connect(self): # pragma: no cover
|
def connect(self): # pragma: no cover
|
||||||
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
|
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
|
||||||
@ -448,28 +452,44 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def data_queue(self):
|
def data_queue(self):
|
||||||
if self.response_arrived.is_set():
|
if self.response_message.arrived.is_set():
|
||||||
return self.response_data_queue
|
return self.response_message.data_queue
|
||||||
else:
|
else:
|
||||||
return self.request_data_queue
|
return self.request_message.data_queue
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def queued_data_length(self):
|
def queued_data_length(self):
|
||||||
if self.response_arrived.is_set():
|
if self.response_message.arrived.is_set():
|
||||||
return self.response_queued_data_length
|
return self.response_message.queued_data_length
|
||||||
else:
|
else:
|
||||||
return self.request_queued_data_length
|
return self.request_message.queued_data_length
|
||||||
|
|
||||||
@queued_data_length.setter
|
@queued_data_length.setter
|
||||||
def queued_data_length(self, v):
|
def queued_data_length(self, v):
|
||||||
self.request_queued_data_length = v
|
self.request_message.queued_data_length = v
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_finished(self):
|
def stream_ended(self):
|
||||||
if self.response_arrived.is_set():
|
# This indicates that all message headers, the full message body, and all trailers have been received
|
||||||
return self.response_data_finished
|
# https://tools.ietf.org/html/rfc7540#section-8.1
|
||||||
|
if self.response_message.arrived.is_set():
|
||||||
|
return self.response_message.stream_ended
|
||||||
else:
|
else:
|
||||||
return self.request_data_finished
|
return self.request_message.stream_ended
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trailers(self):
|
||||||
|
if self.response_message.arrived.is_set():
|
||||||
|
return self.response_message.trailers
|
||||||
|
else:
|
||||||
|
return self.request_message.trailers
|
||||||
|
|
||||||
|
@trailers.setter
|
||||||
|
def trailers(self, v):
|
||||||
|
if self.response_message.arrived.is_set():
|
||||||
|
self.response_message.trailers = v
|
||||||
|
else:
|
||||||
|
self.request_message.trailers = v
|
||||||
|
|
||||||
def raise_zombie(self, pre_command=None): # pragma: no cover
|
def raise_zombie(self, pre_command=None): # pragma: no cover
|
||||||
connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED
|
connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED
|
||||||
@ -480,13 +500,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def read_request_headers(self, flow):
|
def read_request_headers(self, flow):
|
||||||
self.request_arrived.wait()
|
self.request_message.arrived.wait()
|
||||||
self.raise_zombie()
|
self.raise_zombie()
|
||||||
|
|
||||||
if self.pushed:
|
if self.pushed:
|
||||||
flow.metadata['h2-pushed-stream'] = True
|
flow.metadata['h2-pushed-stream'] = True
|
||||||
|
|
||||||
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers)
|
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_message.headers)
|
||||||
return http.HTTPRequest(
|
return http.HTTPRequest(
|
||||||
first_line_format,
|
first_line_format,
|
||||||
method,
|
method,
|
||||||
@ -495,7 +515,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
port,
|
port,
|
||||||
path,
|
path,
|
||||||
b"HTTP/2.0",
|
b"HTTP/2.0",
|
||||||
self.request_headers,
|
self.request_message.headers,
|
||||||
None,
|
None,
|
||||||
timestamp_start=self.timestamp_start,
|
timestamp_start=self.timestamp_start,
|
||||||
timestamp_end=self.timestamp_end,
|
timestamp_end=self.timestamp_end,
|
||||||
@ -504,20 +524,24 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def read_request_body(self, request):
|
def read_request_body(self, request):
|
||||||
if not request.stream:
|
if not request.stream:
|
||||||
self.request_data_finished.wait()
|
self.request_message.stream_ended.wait()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
yield self.request_data_queue.get(timeout=0.1)
|
yield self.request_message.data_queue.get(timeout=0.1)
|
||||||
except queue.Empty: # pragma: no cover
|
except queue.Empty: # pragma: no cover
|
||||||
pass
|
pass
|
||||||
if self.request_data_finished.is_set():
|
if self.request_message.stream_ended.is_set():
|
||||||
self.raise_zombie()
|
self.raise_zombie()
|
||||||
while self.request_data_queue.qsize() > 0:
|
while self.request_message.data_queue.qsize() > 0:
|
||||||
yield self.request_data_queue.get()
|
yield self.request_message.data_queue.get()
|
||||||
break
|
break
|
||||||
self.raise_zombie()
|
self.raise_zombie()
|
||||||
|
|
||||||
|
@detect_zombie_stream
|
||||||
|
def read_request_trailers(self, request):
|
||||||
|
return self.request_message.trailers
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def send_request_headers(self, request):
|
def send_request_headers(self, request):
|
||||||
if self.pushed:
|
if self.pushed:
|
||||||
@ -567,7 +591,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
self.raise_zombie,
|
self.raise_zombie,
|
||||||
self.server_stream_id,
|
self.server_stream_id,
|
||||||
headers,
|
headers,
|
||||||
end_stream=self.no_body,
|
|
||||||
priority_exclusive=priority_exclusive,
|
priority_exclusive=priority_exclusive,
|
||||||
priority_depends_on=priority_depends_on,
|
priority_depends_on=priority_depends_on,
|
||||||
priority_weight=priority_weight,
|
priority_weight=priority_weight,
|
||||||
@ -584,26 +607,31 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
# nothing to do here
|
# nothing to do here
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.no_body:
|
|
||||||
self.connections[self.server_conn].safe_send_body(
|
self.connections[self.server_conn].safe_send_body(
|
||||||
self.raise_zombie,
|
self.raise_zombie,
|
||||||
self.server_stream_id,
|
self.server_stream_id,
|
||||||
chunks
|
chunks,
|
||||||
|
end_stream=(request.trailers is None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def send_request(self, message):
|
def send_request_trailers(self, request):
|
||||||
self.send_request_headers(message)
|
self._send_trailers(self.server_conn, request.trailers)
|
||||||
self.send_request_body(message, [message.content])
|
|
||||||
|
@detect_zombie_stream
|
||||||
|
def send_request(self, request):
|
||||||
|
self.send_request_headers(request)
|
||||||
|
self.send_request_body(request, [request.content])
|
||||||
|
self.send_request_trailers(request)
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def read_response_headers(self):
|
def read_response_headers(self):
|
||||||
self.response_arrived.wait()
|
self.response_message.arrived.wait()
|
||||||
|
|
||||||
self.raise_zombie()
|
self.raise_zombie()
|
||||||
|
|
||||||
status_code = int(self.response_headers.get(':status', 502))
|
status_code = int(self.response_message.headers.get(':status', 502))
|
||||||
headers = self.response_headers.copy()
|
headers = self.response_message.headers.copy()
|
||||||
headers.pop(":status", None)
|
headers.pop(":status", None)
|
||||||
|
|
||||||
return http.HTTPResponse(
|
return http.HTTPResponse(
|
||||||
@ -620,16 +648,20 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
def read_response_body(self, request, response):
|
def read_response_body(self, request, response):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
yield self.response_data_queue.get(timeout=0.1)
|
yield self.response_message.data_queue.get(timeout=0.1)
|
||||||
except queue.Empty: # pragma: no cover
|
except queue.Empty: # pragma: no cover
|
||||||
pass
|
pass
|
||||||
if self.response_data_finished.is_set():
|
if self.response_message.stream_ended.is_set():
|
||||||
self.raise_zombie()
|
self.raise_zombie()
|
||||||
while self.response_data_queue.qsize() > 0:
|
while self.response_message.data_queue.qsize() > 0:
|
||||||
yield self.response_data_queue.get()
|
yield self.response_message.data_queue.get()
|
||||||
break
|
break
|
||||||
self.raise_zombie()
|
self.raise_zombie()
|
||||||
|
|
||||||
|
@detect_zombie_stream
|
||||||
|
def read_response_trailers(self, request, response):
|
||||||
|
return self.response_message.trailers
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def send_response_headers(self, response):
|
def send_response_headers(self, response):
|
||||||
headers = response.headers.copy()
|
headers = response.headers.copy()
|
||||||
@ -642,11 +674,27 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
|||||||
)
|
)
|
||||||
|
|
||||||
@detect_zombie_stream
|
@detect_zombie_stream
|
||||||
def send_response_body(self, _response, chunks):
|
def send_response_body(self, response, chunks):
|
||||||
self.connections[self.client_conn].safe_send_body(
|
self.connections[self.client_conn].safe_send_body(
|
||||||
self.raise_zombie,
|
self.raise_zombie,
|
||||||
self.client_stream_id,
|
self.client_stream_id,
|
||||||
chunks
|
chunks,
|
||||||
|
end_stream=(response.trailers is None),
|
||||||
|
)
|
||||||
|
|
||||||
|
@detect_zombie_stream
|
||||||
|
def send_response_trailers(self, response):
|
||||||
|
self._send_trailers(self.client_conn, response.trailers)
|
||||||
|
|
||||||
|
def _send_trailers(self, conn, trailers):
|
||||||
|
if not trailers:
|
||||||
|
return
|
||||||
|
with self.connections[conn].lock:
|
||||||
|
self.connections[conn].safe_send_headers(
|
||||||
|
self.raise_zombie,
|
||||||
|
self.client_stream_id,
|
||||||
|
trailers,
|
||||||
|
end_stream=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self): # pragma: no cover
|
def __call__(self): # pragma: no cover
|
||||||
|
@ -93,6 +93,9 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
|||||||
"timestamp_end": flow.response.timestamp_end,
|
"timestamp_end": flow.response.timestamp_end,
|
||||||
"is_replay": flow.response.is_replay,
|
"is_replay": flow.response.is_replay,
|
||||||
}
|
}
|
||||||
|
if flow.response.data.trailers:
|
||||||
|
f["response"]["trailers"] = tuple(flow.response.data.trailers.items(True))
|
||||||
|
|
||||||
f.get("server_conn", {}).pop("cert", None)
|
f.get("server_conn", {}).pop("cert", None)
|
||||||
f.get("client_conn", {}).pop("mitmcert", None)
|
f.get("client_conn", {}).pop("mitmcert", None)
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ MITMPROXY = "mitmproxy " + VERSION
|
|||||||
|
|
||||||
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
|
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
|
||||||
# for each change in the file format.
|
# for each change in the file format.
|
||||||
FLOW_FORMAT_VERSION = 7
|
FLOW_FORMAT_VERSION = 8
|
||||||
|
|
||||||
|
|
||||||
def get_dev_version() -> str:
|
def get_dev_version() -> str:
|
||||||
|
@ -110,8 +110,9 @@ class HTTP2StateProtocol:
|
|||||||
b"HTTP/2.0",
|
b"HTTP/2.0",
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
timestamp_start,
|
None,
|
||||||
timestamp_end,
|
timestamp_start=timestamp_start,
|
||||||
|
timestamp_end=timestamp_end,
|
||||||
)
|
)
|
||||||
request.stream_id = stream_id
|
request.stream_id = stream_id
|
||||||
|
|
||||||
|
@ -21,8 +21,11 @@ class TestRequestData:
|
|||||||
treq(headers="foobar")
|
treq(headers="foobar")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
treq(content="foobar")
|
treq(content="foobar")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
treq(trailers="foobar")
|
||||||
|
|
||||||
assert isinstance(treq(headers=()).headers, Headers)
|
assert isinstance(treq(headers=()).headers, Headers)
|
||||||
|
assert isinstance(treq(trailers=()).trailers, Headers)
|
||||||
|
|
||||||
|
|
||||||
class TestRequestCore:
|
class TestRequestCore:
|
||||||
|
@ -20,8 +20,11 @@ class TestResponseData:
|
|||||||
tresp(reason="fööbär")
|
tresp(reason="fööbär")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
tresp(content="foobar")
|
tresp(content="foobar")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
tresp(trailers="foobar")
|
||||||
|
|
||||||
assert isinstance(tresp(headers=()).headers, Headers)
|
assert isinstance(tresp(headers=()).headers, Headers)
|
||||||
|
assert isinstance(tresp(trailers=()).trailers, Headers)
|
||||||
|
|
||||||
|
|
||||||
class TestResponseCore:
|
class TestResponseCore:
|
||||||
|
@ -1031,3 +1031,147 @@ class TestResponseStreaming(_Http2Test):
|
|||||||
assert data
|
assert data
|
||||||
else:
|
else:
|
||||||
assert data is None
|
assert data is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestTrailers(_Http2Test):
|
||||||
|
server_trailers_received = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_server_event(cls, event, h2_conn, rfile, wfile):
|
||||||
|
if isinstance(event, h2.events.RequestReceived):
|
||||||
|
# reset the value for a fresh test
|
||||||
|
cls.server_trailers_received = False
|
||||||
|
elif isinstance(event, h2.events.ConnectionTerminated):
|
||||||
|
return False
|
||||||
|
elif isinstance(event, h2.events.TrailersReceived):
|
||||||
|
cls.server_trailers_received = True
|
||||||
|
|
||||||
|
elif isinstance(event, h2.events.StreamEnded):
|
||||||
|
h2_conn.send_headers(event.stream_id, [
|
||||||
|
(':status', '200'),
|
||||||
|
('x-my-trailer-request-received', 'success' if cls.server_trailers_received else "failure"),
|
||||||
|
], end_stream=True)
|
||||||
|
wfile.write(h2_conn.data_to_send())
|
||||||
|
wfile.flush()
|
||||||
|
return True
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('announce', [True, False])
|
||||||
|
@pytest.mark.parametrize('body', [None, b"foobar"])
|
||||||
|
def test_trailers(self, announce, body):
|
||||||
|
h2_conn = self.setup_connection()
|
||||||
|
stream_id = 1
|
||||||
|
headers = [
|
||||||
|
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
|
||||||
|
(':method', 'GET'),
|
||||||
|
(':scheme', 'https'),
|
||||||
|
(':path', '/'),
|
||||||
|
]
|
||||||
|
if announce:
|
||||||
|
headers.append(('trailer', 'x-my-trailers'))
|
||||||
|
h2_conn.send_headers(
|
||||||
|
stream_id=stream_id,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
if body:
|
||||||
|
h2_conn.send_data(stream_id, body)
|
||||||
|
|
||||||
|
# send trailers
|
||||||
|
h2_conn.send_headers(stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
|
||||||
|
|
||||||
|
self.client.wfile.write(h2_conn.data_to_send())
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
try:
|
||||||
|
raw = b''.join(http2.read_raw_frame(self.client.rfile))
|
||||||
|
events = h2_conn.receive_data(raw)
|
||||||
|
except exceptions.HttpException:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
assert False
|
||||||
|
|
||||||
|
self.client.wfile.write(h2_conn.data_to_send())
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
if isinstance(event, h2.events.StreamEnded):
|
||||||
|
done = True
|
||||||
|
|
||||||
|
h2_conn.close_connection()
|
||||||
|
self.client.wfile.write(h2_conn.data_to_send())
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
assert len(self.master.state.flows) == 1
|
||||||
|
assert self.master.state.flows[0].request.trailers['x-my-trailers'] == 'foobar'
|
||||||
|
assert self.master.state.flows[0].response.status_code == 200
|
||||||
|
assert self.master.state.flows[0].response.headers['x-my-trailer-request-received'] == 'success'
|
||||||
|
|
||||||
|
|
||||||
|
class TestResponseTrailers(_Http2Test):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_server_event(cls, event, h2_conn, rfile, wfile):
|
||||||
|
if isinstance(event, h2.events.ConnectionTerminated):
|
||||||
|
return False
|
||||||
|
elif isinstance(event, h2.events.StreamEnded):
|
||||||
|
headers = [
|
||||||
|
(':status', '200'),
|
||||||
|
]
|
||||||
|
if event.stream_id == 1:
|
||||||
|
# special stream_id to activate the Trailer announcement header
|
||||||
|
headers.append(('trailer', 'x-my-trailers'))
|
||||||
|
|
||||||
|
h2_conn.send_headers(event.stream_id, headers)
|
||||||
|
h2_conn.send_data(event.stream_id, b'response body')
|
||||||
|
h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
|
||||||
|
wfile.write(h2_conn.data_to_send())
|
||||||
|
wfile.flush()
|
||||||
|
return True
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('announce', [True, False])
|
||||||
|
def test_trailers(self, announce):
|
||||||
|
response_body_buffer = b''
|
||||||
|
h2_conn = self.setup_connection()
|
||||||
|
|
||||||
|
self._send_request(
|
||||||
|
self.client.wfile,
|
||||||
|
h2_conn,
|
||||||
|
stream_id=(1 if announce else 3),
|
||||||
|
headers=[
|
||||||
|
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
|
||||||
|
(':method', 'GET'),
|
||||||
|
(':scheme', 'https'),
|
||||||
|
(':path', '/'),
|
||||||
|
])
|
||||||
|
|
||||||
|
trailers_buffer = None
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
try:
|
||||||
|
raw = b''.join(http2.read_raw_frame(self.client.rfile))
|
||||||
|
events = h2_conn.receive_data(raw)
|
||||||
|
except exceptions.HttpException:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
assert False
|
||||||
|
|
||||||
|
self.client.wfile.write(h2_conn.data_to_send())
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
if isinstance(event, h2.events.DataReceived):
|
||||||
|
response_body_buffer += event.data
|
||||||
|
elif isinstance(event, h2.events.TrailersReceived):
|
||||||
|
trailers_buffer = event.headers
|
||||||
|
elif isinstance(event, h2.events.StreamEnded):
|
||||||
|
done = True
|
||||||
|
|
||||||
|
h2_conn.close_connection()
|
||||||
|
self.client.wfile.write(h2_conn.data_to_send())
|
||||||
|
self.client.wfile.flush()
|
||||||
|
|
||||||
|
assert len(self.master.state.flows) == 1
|
||||||
|
assert self.master.state.flows[0].response.status_code == 200
|
||||||
|
assert self.master.state.flows[0].response.content == b'response body'
|
||||||
|
assert response_body_buffer == b'response body'
|
||||||
|
assert self.master.state.flows[0].response.trailers['x-my-trailers'] == 'foobar'
|
||||||
|
assert trailers_buffer == [(b'x-my-trailers', b'foobar')]
|
||||||
|
Loading…
Reference in New Issue
Block a user