mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
Merge pull request #4042 from sanlengjingvv/develop
support HTTP/2 trailers
This commit is contained in:
commit
46a0f69485
28
examples/addons/http-trailers.py
Normal file
28
examples/addons/http-trailers.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""
|
||||
This script simply prints all received HTTP Trailers.
|
||||
|
||||
HTTP requests and responses can container trailing headers which are sent after
|
||||
the body is fully transmitted. Such trailers need to be announced in the initial
|
||||
headers by name, so the receiving endpoint can wait and read them after the
|
||||
body.
|
||||
"""
|
||||
|
||||
from mitmproxy import http
|
||||
from mitmproxy.net.http import Headers
|
||||
|
||||
|
||||
def request(flow: http.HTTPFlow):
|
||||
if flow.request.trailers:
|
||||
print("HTTP Trailers detected! Request contains:", flow.request.trailers)
|
||||
|
||||
|
||||
def response(flow: http.HTTPFlow):
|
||||
if flow.response.trailers:
|
||||
print("HTTP Trailers detected! Response contains:", flow.response.trailers)
|
||||
|
||||
if flow.request.path == "/inject_trailers":
|
||||
flow.response.headers["trailer"] = "x-my-injected-trailer-header"
|
||||
flow.response.trailers = Headers([
|
||||
(b"x-my-injected-trailer-header", b"foobar")
|
||||
])
|
||||
print("Injected a new trailer...", flow.response.headers["trailer"])
|
@ -26,6 +26,7 @@ class HTTPRequest(http.Request):
|
||||
http_version,
|
||||
headers,
|
||||
content,
|
||||
trailers=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
is_replay=False,
|
||||
@ -41,6 +42,7 @@ class HTTPRequest(http.Request):
|
||||
http_version,
|
||||
headers,
|
||||
content,
|
||||
trailers,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
)
|
||||
@ -73,6 +75,7 @@ class HTTPRequest(http.Request):
|
||||
http_version=request.data.http_version,
|
||||
headers=request.data.headers,
|
||||
content=request.data.content,
|
||||
trailers=request.data.trailers,
|
||||
timestamp_start=request.data.timestamp_start,
|
||||
timestamp_end=request.data.timestamp_end,
|
||||
)
|
||||
@ -97,6 +100,7 @@ class HTTPResponse(http.Response):
|
||||
reason,
|
||||
headers,
|
||||
content,
|
||||
trailers=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
is_replay=False
|
||||
@ -108,6 +112,7 @@ class HTTPResponse(http.Response):
|
||||
reason,
|
||||
headers,
|
||||
content,
|
||||
trailers,
|
||||
timestamp_start=timestamp_start,
|
||||
timestamp_end=timestamp_end,
|
||||
)
|
||||
@ -127,6 +132,7 @@ class HTTPResponse(http.Response):
|
||||
reason=response.data.reason,
|
||||
headers=response.data.headers,
|
||||
content=response.data.content,
|
||||
trailers=response.data.trailers,
|
||||
timestamp_start=response.data.timestamp_start,
|
||||
timestamp_end=response.data.timestamp_end,
|
||||
)
|
||||
|
@ -172,6 +172,13 @@ def convert_6_7(data):
|
||||
return data
|
||||
|
||||
|
||||
def convert_7_8(data):
|
||||
data["version"] = 8
|
||||
data["request"]["trailers"] = None
|
||||
data["response"]["trailers"] = None
|
||||
return data
|
||||
|
||||
|
||||
def _convert_dict_keys(o: Any) -> Any:
|
||||
if isinstance(o, dict):
|
||||
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
|
||||
@ -226,6 +233,7 @@ converters = {
|
||||
4: convert_4_5,
|
||||
5: convert_5_6,
|
||||
6: convert_6_7,
|
||||
7: convert_7_8,
|
||||
}
|
||||
|
||||
|
||||
|
@ -59,7 +59,7 @@ def read_request_head(rfile):
|
||||
timestamp_start = rfile.first_byte_timestamp
|
||||
|
||||
return request.Request(
|
||||
form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
|
||||
form, method, scheme, host, port, path, http_version, headers, None, None, timestamp_start
|
||||
)
|
||||
|
||||
|
||||
@ -98,7 +98,7 @@ def read_response_head(rfile):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = rfile.first_byte_timestamp
|
||||
|
||||
return response.Response(http_version, status_code, message, headers, None, timestamp_start)
|
||||
return response.Response(http_version, status_code, message, headers, None, None, timestamp_start)
|
||||
|
||||
|
||||
def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
|
||||
|
@ -28,6 +28,8 @@ class MessageData(serializable.Serializable):
|
||||
def get_state(self):
|
||||
state = vars(self).copy()
|
||||
state["headers"] = state["headers"].get_state()
|
||||
if 'trailers' in state and state["trailers"] is not None:
|
||||
state["trailers"] = state["trailers"].get_state()
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
@ -53,6 +55,8 @@ class Message(serializable.Serializable):
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
state["headers"] = mheaders.Headers.from_state(state["headers"])
|
||||
if 'trailers' in state and state["trailers"] is not None:
|
||||
state["trailers"] = mheaders.Headers.from_state(state["trailers"])
|
||||
return cls(**state)
|
||||
|
||||
@property
|
||||
@ -130,6 +134,20 @@ class Message(serializable.Serializable):
|
||||
|
||||
content = property(get_content, set_content)
|
||||
|
||||
@property
|
||||
def trailers(self):
|
||||
"""
|
||||
Message trailers object
|
||||
|
||||
Returns:
|
||||
mitmproxy.net.http.Headers
|
||||
"""
|
||||
return self.data.trailers
|
||||
|
||||
@trailers.setter
|
||||
def trailers(self, h):
|
||||
self.data.trailers = h
|
||||
|
||||
@property
|
||||
def http_version(self):
|
||||
"""
|
||||
|
@ -29,6 +29,7 @@ class RequestData(message.MessageData):
|
||||
http_version,
|
||||
headers=(),
|
||||
content=None,
|
||||
trailers=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None
|
||||
):
|
||||
@ -46,6 +47,8 @@ class RequestData(message.MessageData):
|
||||
headers = nheaders.Headers(headers)
|
||||
if isinstance(content, str):
|
||||
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
||||
if trailers is not None and not isinstance(trailers, nheaders.Headers):
|
||||
trailers = nheaders.Headers(trailers)
|
||||
|
||||
self.first_line_format = first_line_format
|
||||
self.method = method
|
||||
@ -56,6 +59,7 @@ class RequestData(message.MessageData):
|
||||
self.http_version = http_version
|
||||
self.headers = headers
|
||||
self.content = content
|
||||
self.trailers = trailers
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
|
@ -22,6 +22,7 @@ class ResponseData(message.MessageData):
|
||||
reason=None,
|
||||
headers=(),
|
||||
content=None,
|
||||
trailers=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None
|
||||
):
|
||||
@ -33,12 +34,15 @@ class ResponseData(message.MessageData):
|
||||
headers = nheaders.Headers(headers)
|
||||
if isinstance(content, str):
|
||||
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
||||
if trailers is not None and not isinstance(trailers, nheaders.Headers):
|
||||
trailers = nheaders.Headers(trailers)
|
||||
|
||||
self.http_version = http_version
|
||||
self.status_code = status_code
|
||||
self.reason = reason
|
||||
self.headers = headers
|
||||
self.content = content
|
||||
self.trailers = trailers
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
|
@ -20,6 +20,9 @@ class _HttpTransmissionLayer(base.Layer):
|
||||
def read_request_body(self, request):
|
||||
raise NotImplementedError()
|
||||
|
||||
def read_request_trailers(self, request):
|
||||
raise NotImplementedError()
|
||||
|
||||
def send_request(self, request):
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -30,11 +33,15 @@ class _HttpTransmissionLayer(base.Layer):
|
||||
raise NotImplementedError()
|
||||
yield "this is a generator" # pragma: no cover
|
||||
|
||||
def read_response_trailers(self, request, response):
|
||||
raise NotImplementedError()
|
||||
|
||||
def read_response(self, request):
|
||||
response = self.read_response_headers()
|
||||
response.data.content = b"".join(
|
||||
self.read_response_body(request, response)
|
||||
)
|
||||
response.data.trailers = self.read_response_trailers(request, response)
|
||||
return response
|
||||
|
||||
def send_response(self, response):
|
||||
@ -42,6 +49,7 @@ class _HttpTransmissionLayer(base.Layer):
|
||||
raise exceptions.HttpException("Cannot assemble flow with missing content")
|
||||
self.send_response_headers(response)
|
||||
self.send_response_body(response, [response.data.content])
|
||||
self.send_response_trailers(response)
|
||||
|
||||
def send_response_headers(self, response):
|
||||
raise NotImplementedError()
|
||||
@ -49,6 +57,9 @@ class _HttpTransmissionLayer(base.Layer):
|
||||
def send_response_body(self, response, chunks):
|
||||
raise NotImplementedError()
|
||||
|
||||
def send_response_trailers(self, response, chunks):
|
||||
raise NotImplementedError()
|
||||
|
||||
def check_close_connection(self, f):
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -255,6 +266,7 @@ class HttpLayer(base.Layer):
|
||||
f.request.data.content = b"".join(
|
||||
self.read_request_body(f.request)
|
||||
)
|
||||
f.request.data.trailers = self.read_request_trailers(f.request)
|
||||
f.request.timestamp_end = time.time()
|
||||
self.channel.ask("http_connect", f)
|
||||
|
||||
@ -282,6 +294,9 @@ class HttpLayer(base.Layer):
|
||||
f.request.data.content = None
|
||||
else:
|
||||
f.request.data.content = b"".join(self.read_request_body(request))
|
||||
|
||||
f.request.data.trailers = self.read_request_trailers(f.request)
|
||||
|
||||
request.timestamp_end = time.time()
|
||||
except exceptions.HttpException as e:
|
||||
# We optimistically guess there might be an HTTP client on the
|
||||
@ -348,6 +363,8 @@ class HttpLayer(base.Layer):
|
||||
else:
|
||||
self.send_request_body(f.request, [f.request.data.content])
|
||||
|
||||
self.send_request_trailers(f.request)
|
||||
|
||||
f.response = self.read_response_headers()
|
||||
|
||||
try:
|
||||
@ -406,6 +423,8 @@ class HttpLayer(base.Layer):
|
||||
# we now need to emulate the responseheaders hook.
|
||||
self.channel.ask("responseheaders", f)
|
||||
|
||||
f.response.data.trailers = self.read_response_trailers(f.request, f.response)
|
||||
|
||||
self.log("response", "debug", [repr(f.response)])
|
||||
self.channel.ask("response", f)
|
||||
|
||||
|
@ -23,6 +23,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
||||
human.parse_size(self.config.options.body_size_limit)
|
||||
)
|
||||
|
||||
def read_request_trailers(self, request):
|
||||
if "Trailer" in request.headers:
|
||||
# TODO: not implemented yet
|
||||
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
|
||||
return None
|
||||
|
||||
def send_request_headers(self, request):
|
||||
headers = http1.assemble_request_head(request)
|
||||
self.server_conn.wfile.write(headers)
|
||||
@ -33,7 +39,13 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
||||
self.server_conn.wfile.write(chunk)
|
||||
self.server_conn.wfile.flush()
|
||||
|
||||
def send_request_trailers(self, request):
|
||||
if "Trailer" in request.headers:
|
||||
# TODO: not implemented yet
|
||||
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
|
||||
|
||||
def send_request(self, request):
|
||||
# TODO: this does not yet support request trailers
|
||||
self.server_conn.wfile.write(http1.assemble_request(request))
|
||||
self.server_conn.wfile.flush()
|
||||
|
||||
@ -49,6 +61,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
||||
human.parse_size(self.config.options.body_size_limit)
|
||||
)
|
||||
|
||||
def read_response_trailers(self, request, response):
|
||||
if "Trailer" in response.headers:
|
||||
# TODO: not implemented yet
|
||||
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
|
||||
return None
|
||||
|
||||
def send_response_headers(self, response):
|
||||
raw = http1.assemble_response_head(response)
|
||||
self.client_conn.wfile.write(raw)
|
||||
@ -59,6 +77,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
||||
self.client_conn.wfile.write(chunk)
|
||||
self.client_conn.wfile.flush()
|
||||
|
||||
def send_response_trailers(self, response):
|
||||
if "Trailer" in response.headers:
|
||||
# TODO: not implemented yet
|
||||
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
|
||||
return
|
||||
|
||||
def check_close_connection(self, flow):
|
||||
request_close = http1.connection_close(
|
||||
flow.request.http_version,
|
||||
|
@ -55,7 +55,7 @@ class SafeH2Connection(connection.H2Connection):
|
||||
self.send_headers(stream_id, headers.fields, **kwargs)
|
||||
self.conn.send(self.data_to_send())
|
||||
|
||||
def safe_send_body(self, raise_zombie: Callable, stream_id: int, chunks: List[bytes]):
|
||||
def safe_send_body(self, raise_zombie: Callable, stream_id: int, chunks: List[bytes], end_stream=True):
|
||||
for chunk in chunks:
|
||||
position = 0
|
||||
while position < len(chunk):
|
||||
@ -75,10 +75,11 @@ class SafeH2Connection(connection.H2Connection):
|
||||
finally:
|
||||
self.lock.release()
|
||||
position += max_outbound_frame_size
|
||||
with self.lock:
|
||||
raise_zombie()
|
||||
self.end_stream(stream_id)
|
||||
self.conn.send(self.data_to_send())
|
||||
if end_stream:
|
||||
with self.lock:
|
||||
raise_zombie()
|
||||
self.end_stream(stream_id)
|
||||
self.conn.send(self.data_to_send())
|
||||
|
||||
|
||||
class Http2Layer(base.Layer):
|
||||
@ -170,7 +171,7 @@ class Http2Layer(base.Layer):
|
||||
elif isinstance(event, events.PriorityUpdated):
|
||||
return self._handle_priority_updated(eid, event)
|
||||
elif isinstance(event, events.TrailersReceived):
|
||||
raise NotImplementedError('TrailersReceived not implemented')
|
||||
return self._handle_trailers(eid, event, is_server, other_conn)
|
||||
|
||||
# fail-safe for unhandled events
|
||||
return True
|
||||
@ -179,22 +180,21 @@ class Http2Layer(base.Layer):
|
||||
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
||||
self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers)
|
||||
self.streams[eid].timestamp_start = time.time()
|
||||
self.streams[eid].no_body = (event.stream_ended is not None)
|
||||
if event.priority_updated is not None:
|
||||
self.streams[eid].priority_exclusive = event.priority_updated.exclusive
|
||||
self.streams[eid].priority_depends_on = event.priority_updated.depends_on
|
||||
self.streams[eid].priority_weight = event.priority_updated.weight
|
||||
self.streams[eid].handled_priority_event = event.priority_updated
|
||||
self.streams[eid].start()
|
||||
self.streams[eid].request_arrived.set()
|
||||
self.streams[eid].request_message.arrived.set()
|
||||
return True
|
||||
|
||||
def _handle_response_received(self, eid, event):
|
||||
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
||||
self.streams[eid].queued_data_length = 0
|
||||
self.streams[eid].timestamp_start = time.time()
|
||||
self.streams[eid].response_headers = headers
|
||||
self.streams[eid].response_arrived.set()
|
||||
self.streams[eid].response_message.headers = headers
|
||||
self.streams[eid].response_message.arrived.set()
|
||||
return True
|
||||
|
||||
def _handle_data_received(self, eid, event, source_conn):
|
||||
@ -219,7 +219,7 @@ class Http2Layer(base.Layer):
|
||||
|
||||
def _handle_stream_ended(self, eid):
|
||||
self.streams[eid].timestamp_end = time.time()
|
||||
self.streams[eid].data_finished.set()
|
||||
self.streams[eid].stream_ended.set()
|
||||
return True
|
||||
|
||||
def _handle_stream_reset(self, eid, event, is_server, other_conn):
|
||||
@ -233,6 +233,11 @@ class Http2Layer(base.Layer):
|
||||
self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code)
|
||||
return True
|
||||
|
||||
def _handle_trailers(self, eid, event, is_server, other_conn):
|
||||
trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
||||
self.streams[eid].trailers = trailers
|
||||
return True
|
||||
|
||||
def _handle_remote_settings_changed(self, event, other_conn):
|
||||
new_settings = dict([(key, cs.new_value) for (key, cs) in event.changed_settings.items()])
|
||||
self.connections[other_conn].safe_update_settings(new_settings)
|
||||
@ -277,8 +282,8 @@ class Http2Layer(base.Layer):
|
||||
self.streams[event.pushed_stream_id].pushed = True
|
||||
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
|
||||
self.streams[event.pushed_stream_id].timestamp_end = time.time()
|
||||
self.streams[event.pushed_stream_id].request_arrived.set()
|
||||
self.streams[event.pushed_stream_id].request_data_finished.set()
|
||||
self.streams[event.pushed_stream_id].request_message.arrived.set()
|
||||
self.streams[event.pushed_stream_id].request_message.stream_ended.set()
|
||||
self.streams[event.pushed_stream_id].start()
|
||||
return True
|
||||
|
||||
@ -392,6 +397,16 @@ def detect_zombie_stream(func): # pragma: no cover
|
||||
|
||||
class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThread):
|
||||
|
||||
class Message:
|
||||
def __init__(self, headers=None):
|
||||
self.headers: Optional[mitmproxy.net.http.Headers] = headers # headers are the first thing to be received on a new stream
|
||||
self.data_queue: queue.Queue[bytes] = queue.Queue() # contains raw contents of DATA frames
|
||||
self.queued_data_length = 0 # used to enforce mitmproxy's config.options.body_size_limit
|
||||
self.trailers: Optional[mitmproxy.net.http.Headers] = None # trailers are received after stream_ended is set
|
||||
|
||||
self.arrived = threading.Event() # indicates the HEADERS+CONTINUTATION frames have been received
|
||||
self.stream_ended = threading.Event() # indicates the a frame with the END_STREAM flag has been received
|
||||
|
||||
def __init__(self, ctx, h2_connection, stream_id: int, request_headers: mitmproxy.net.http.Headers) -> None:
|
||||
super().__init__(
|
||||
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
|
||||
@ -400,24 +415,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
self.zombie: Optional[float] = None
|
||||
self.client_stream_id: int = stream_id
|
||||
self.server_stream_id: Optional[int] = None
|
||||
self.request_headers = request_headers
|
||||
self.response_headers: Optional[mitmproxy.net.http.Headers] = None
|
||||
self.pushed = False
|
||||
|
||||
self.timestamp_start: Optional[float] = None
|
||||
self.timestamp_end: Optional[float] = None
|
||||
|
||||
self.request_arrived = threading.Event()
|
||||
self.request_data_queue: queue.Queue[bytes] = queue.Queue()
|
||||
self.request_queued_data_length = 0
|
||||
self.request_data_finished = threading.Event()
|
||||
|
||||
self.response_arrived = threading.Event()
|
||||
self.response_data_queue: queue.Queue[bytes] = queue.Queue()
|
||||
self.response_queued_data_length = 0
|
||||
self.response_data_finished = threading.Event()
|
||||
|
||||
self.no_body = False
|
||||
self.request_message = self.Message(request_headers)
|
||||
self.response_message = self.Message()
|
||||
|
||||
self.priority_exclusive: bool
|
||||
self.priority_depends_on: Optional[int] = None
|
||||
@ -427,10 +431,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
def kill(self):
|
||||
if not self.zombie:
|
||||
self.zombie = time.time()
|
||||
self.request_data_finished.set()
|
||||
self.request_arrived.set()
|
||||
self.response_arrived.set()
|
||||
self.response_data_finished.set()
|
||||
self.request_message.stream_ended.set()
|
||||
self.request_message.arrived.set()
|
||||
self.response_message.arrived.set()
|
||||
self.response_message.stream_ended.set()
|
||||
|
||||
def connect(self): # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
|
||||
@ -448,28 +452,44 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
|
||||
@property
|
||||
def data_queue(self):
|
||||
if self.response_arrived.is_set():
|
||||
return self.response_data_queue
|
||||
if self.response_message.arrived.is_set():
|
||||
return self.response_message.data_queue
|
||||
else:
|
||||
return self.request_data_queue
|
||||
return self.request_message.data_queue
|
||||
|
||||
@property
|
||||
def queued_data_length(self):
|
||||
if self.response_arrived.is_set():
|
||||
return self.response_queued_data_length
|
||||
if self.response_message.arrived.is_set():
|
||||
return self.response_message.queued_data_length
|
||||
else:
|
||||
return self.request_queued_data_length
|
||||
return self.request_message.queued_data_length
|
||||
|
||||
@queued_data_length.setter
|
||||
def queued_data_length(self, v):
|
||||
self.request_queued_data_length = v
|
||||
self.request_message.queued_data_length = v
|
||||
|
||||
@property
|
||||
def data_finished(self):
|
||||
if self.response_arrived.is_set():
|
||||
return self.response_data_finished
|
||||
def stream_ended(self):
|
||||
# This indicates that all message headers, the full message body, and all trailers have been received
|
||||
# https://tools.ietf.org/html/rfc7540#section-8.1
|
||||
if self.response_message.arrived.is_set():
|
||||
return self.response_message.stream_ended
|
||||
else:
|
||||
return self.request_data_finished
|
||||
return self.request_message.stream_ended
|
||||
|
||||
@property
|
||||
def trailers(self):
|
||||
if self.response_message.arrived.is_set():
|
||||
return self.response_message.trailers
|
||||
else:
|
||||
return self.request_message.trailers
|
||||
|
||||
@trailers.setter
|
||||
def trailers(self, v):
|
||||
if self.response_message.arrived.is_set():
|
||||
self.response_message.trailers = v
|
||||
else:
|
||||
self.request_message.trailers = v
|
||||
|
||||
def raise_zombie(self, pre_command=None): # pragma: no cover
|
||||
connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED
|
||||
@ -480,13 +500,13 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_request_headers(self, flow):
|
||||
self.request_arrived.wait()
|
||||
self.request_message.arrived.wait()
|
||||
self.raise_zombie()
|
||||
|
||||
if self.pushed:
|
||||
flow.metadata['h2-pushed-stream'] = True
|
||||
|
||||
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers)
|
||||
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_message.headers)
|
||||
return http.HTTPRequest(
|
||||
first_line_format,
|
||||
method,
|
||||
@ -495,7 +515,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
port,
|
||||
path,
|
||||
b"HTTP/2.0",
|
||||
self.request_headers,
|
||||
self.request_message.headers,
|
||||
None,
|
||||
timestamp_start=self.timestamp_start,
|
||||
timestamp_end=self.timestamp_end,
|
||||
@ -504,20 +524,24 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
@detect_zombie_stream
|
||||
def read_request_body(self, request):
|
||||
if not request.stream:
|
||||
self.request_data_finished.wait()
|
||||
self.request_message.stream_ended.wait()
|
||||
|
||||
while True:
|
||||
try:
|
||||
yield self.request_data_queue.get(timeout=0.1)
|
||||
yield self.request_message.data_queue.get(timeout=0.1)
|
||||
except queue.Empty: # pragma: no cover
|
||||
pass
|
||||
if self.request_data_finished.is_set():
|
||||
if self.request_message.stream_ended.is_set():
|
||||
self.raise_zombie()
|
||||
while self.request_data_queue.qsize() > 0:
|
||||
yield self.request_data_queue.get()
|
||||
while self.request_message.data_queue.qsize() > 0:
|
||||
yield self.request_message.data_queue.get()
|
||||
break
|
||||
self.raise_zombie()
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_request_trailers(self, request):
|
||||
return self.request_message.trailers
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_request_headers(self, request):
|
||||
if self.pushed:
|
||||
@ -567,7 +591,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
self.raise_zombie,
|
||||
self.server_stream_id,
|
||||
headers,
|
||||
end_stream=self.no_body,
|
||||
priority_exclusive=priority_exclusive,
|
||||
priority_depends_on=priority_depends_on,
|
||||
priority_weight=priority_weight,
|
||||
@ -584,26 +607,31 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
# nothing to do here
|
||||
return
|
||||
|
||||
if not self.no_body:
|
||||
self.connections[self.server_conn].safe_send_body(
|
||||
self.raise_zombie,
|
||||
self.server_stream_id,
|
||||
chunks
|
||||
)
|
||||
self.connections[self.server_conn].safe_send_body(
|
||||
self.raise_zombie,
|
||||
self.server_stream_id,
|
||||
chunks,
|
||||
end_stream=(request.trailers is None),
|
||||
)
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_request(self, message):
|
||||
self.send_request_headers(message)
|
||||
self.send_request_body(message, [message.content])
|
||||
def send_request_trailers(self, request):
|
||||
self._send_trailers(self.server_conn, request.trailers)
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_request(self, request):
|
||||
self.send_request_headers(request)
|
||||
self.send_request_body(request, [request.content])
|
||||
self.send_request_trailers(request)
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_response_headers(self):
|
||||
self.response_arrived.wait()
|
||||
self.response_message.arrived.wait()
|
||||
|
||||
self.raise_zombie()
|
||||
|
||||
status_code = int(self.response_headers.get(':status', 502))
|
||||
headers = self.response_headers.copy()
|
||||
status_code = int(self.response_message.headers.get(':status', 502))
|
||||
headers = self.response_message.headers.copy()
|
||||
headers.pop(":status", None)
|
||||
|
||||
return http.HTTPResponse(
|
||||
@ -620,16 +648,20 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
def read_response_body(self, request, response):
|
||||
while True:
|
||||
try:
|
||||
yield self.response_data_queue.get(timeout=0.1)
|
||||
yield self.response_message.data_queue.get(timeout=0.1)
|
||||
except queue.Empty: # pragma: no cover
|
||||
pass
|
||||
if self.response_data_finished.is_set():
|
||||
if self.response_message.stream_ended.is_set():
|
||||
self.raise_zombie()
|
||||
while self.response_data_queue.qsize() > 0:
|
||||
yield self.response_data_queue.get()
|
||||
while self.response_message.data_queue.qsize() > 0:
|
||||
yield self.response_message.data_queue.get()
|
||||
break
|
||||
self.raise_zombie()
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_response_trailers(self, request, response):
|
||||
return self.response_message.trailers
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_response_headers(self, response):
|
||||
headers = response.headers.copy()
|
||||
@ -642,13 +674,29 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
)
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_response_body(self, _response, chunks):
|
||||
def send_response_body(self, response, chunks):
|
||||
self.connections[self.client_conn].safe_send_body(
|
||||
self.raise_zombie,
|
||||
self.client_stream_id,
|
||||
chunks
|
||||
chunks,
|
||||
end_stream=(response.trailers is None),
|
||||
)
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_response_trailers(self, response):
|
||||
self._send_trailers(self.client_conn, response.trailers)
|
||||
|
||||
def _send_trailers(self, conn, trailers):
|
||||
if not trailers:
|
||||
return
|
||||
with self.connections[conn].lock:
|
||||
self.connections[conn].safe_send_headers(
|
||||
self.raise_zombie,
|
||||
self.client_stream_id,
|
||||
trailers,
|
||||
end_stream=True
|
||||
)
|
||||
|
||||
def __call__(self): # pragma: no cover
|
||||
raise EnvironmentError('Http2SingleStreamLayer must be run as thread')
|
||||
|
||||
|
@ -93,6 +93,9 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
|
||||
"timestamp_end": flow.response.timestamp_end,
|
||||
"is_replay": flow.response.is_replay,
|
||||
}
|
||||
if flow.response.data.trailers:
|
||||
f["response"]["trailers"] = tuple(flow.response.data.trailers.items(True))
|
||||
|
||||
f.get("server_conn", {}).pop("cert", None)
|
||||
f.get("client_conn", {}).pop("mitmcert", None)
|
||||
|
||||
|
@ -8,7 +8,7 @@ MITMPROXY = "mitmproxy " + VERSION
|
||||
|
||||
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
|
||||
# for each change in the file format.
|
||||
FLOW_FORMAT_VERSION = 7
|
||||
FLOW_FORMAT_VERSION = 8
|
||||
|
||||
|
||||
def get_dev_version() -> str:
|
||||
|
@ -110,8 +110,9 @@ class HTTP2StateProtocol:
|
||||
b"HTTP/2.0",
|
||||
headers,
|
||||
body,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
None,
|
||||
timestamp_start=timestamp_start,
|
||||
timestamp_end=timestamp_end,
|
||||
)
|
||||
request.stream_id = stream_id
|
||||
|
||||
|
@ -21,8 +21,11 @@ class TestRequestData:
|
||||
treq(headers="foobar")
|
||||
with pytest.raises(ValueError):
|
||||
treq(content="foobar")
|
||||
with pytest.raises(ValueError):
|
||||
treq(trailers="foobar")
|
||||
|
||||
assert isinstance(treq(headers=()).headers, Headers)
|
||||
assert isinstance(treq(trailers=()).trailers, Headers)
|
||||
|
||||
|
||||
class TestRequestCore:
|
||||
|
@ -20,8 +20,11 @@ class TestResponseData:
|
||||
tresp(reason="fööbär")
|
||||
with pytest.raises(ValueError):
|
||||
tresp(content="foobar")
|
||||
with pytest.raises(ValueError):
|
||||
tresp(trailers="foobar")
|
||||
|
||||
assert isinstance(tresp(headers=()).headers, Headers)
|
||||
assert isinstance(tresp(trailers=()).trailers, Headers)
|
||||
|
||||
|
||||
class TestResponseCore:
|
||||
|
@ -1031,3 +1031,147 @@ class TestResponseStreaming(_Http2Test):
|
||||
assert data
|
||||
else:
|
||||
assert data is None
|
||||
|
||||
|
||||
class TestRequestTrailers(_Http2Test):
|
||||
server_trailers_received = False
|
||||
|
||||
@classmethod
|
||||
def handle_server_event(cls, event, h2_conn, rfile, wfile):
|
||||
if isinstance(event, h2.events.RequestReceived):
|
||||
# reset the value for a fresh test
|
||||
cls.server_trailers_received = False
|
||||
elif isinstance(event, h2.events.ConnectionTerminated):
|
||||
return False
|
||||
elif isinstance(event, h2.events.TrailersReceived):
|
||||
cls.server_trailers_received = True
|
||||
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
h2_conn.send_headers(event.stream_id, [
|
||||
(':status', '200'),
|
||||
('x-my-trailer-request-received', 'success' if cls.server_trailers_received else "failure"),
|
||||
], end_stream=True)
|
||||
wfile.write(h2_conn.data_to_send())
|
||||
wfile.flush()
|
||||
return True
|
||||
|
||||
@pytest.mark.parametrize('announce', [True, False])
|
||||
@pytest.mark.parametrize('body', [None, b"foobar"])
|
||||
def test_trailers(self, announce, body):
|
||||
h2_conn = self.setup_connection()
|
||||
stream_id = 1
|
||||
headers = [
|
||||
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
|
||||
(':method', 'GET'),
|
||||
(':scheme', 'https'),
|
||||
(':path', '/'),
|
||||
]
|
||||
if announce:
|
||||
headers.append(('trailer', 'x-my-trailers'))
|
||||
h2_conn.send_headers(
|
||||
stream_id=stream_id,
|
||||
headers=headers,
|
||||
)
|
||||
if body:
|
||||
h2_conn.send_data(stream_id, body)
|
||||
|
||||
# send trailers
|
||||
h2_conn.send_headers(stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
|
||||
|
||||
self.client.wfile.write(h2_conn.data_to_send())
|
||||
self.client.wfile.flush()
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
try:
|
||||
raw = b''.join(http2.read_raw_frame(self.client.rfile))
|
||||
events = h2_conn.receive_data(raw)
|
||||
except exceptions.HttpException:
|
||||
print(traceback.format_exc())
|
||||
assert False
|
||||
|
||||
self.client.wfile.write(h2_conn.data_to_send())
|
||||
self.client.wfile.flush()
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.StreamEnded):
|
||||
done = True
|
||||
|
||||
h2_conn.close_connection()
|
||||
self.client.wfile.write(h2_conn.data_to_send())
|
||||
self.client.wfile.flush()
|
||||
|
||||
assert len(self.master.state.flows) == 1
|
||||
assert self.master.state.flows[0].request.trailers['x-my-trailers'] == 'foobar'
|
||||
assert self.master.state.flows[0].response.status_code == 200
|
||||
assert self.master.state.flows[0].response.headers['x-my-trailer-request-received'] == 'success'
|
||||
|
||||
|
||||
class TestResponseTrailers(_Http2Test):
|
||||
|
||||
@classmethod
|
||||
def handle_server_event(cls, event, h2_conn, rfile, wfile):
|
||||
if isinstance(event, h2.events.ConnectionTerminated):
|
||||
return False
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
headers = [
|
||||
(':status', '200'),
|
||||
]
|
||||
if event.stream_id == 1:
|
||||
# special stream_id to activate the Trailer announcement header
|
||||
headers.append(('trailer', 'x-my-trailers'))
|
||||
|
||||
h2_conn.send_headers(event.stream_id, headers)
|
||||
h2_conn.send_data(event.stream_id, b'response body')
|
||||
h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
|
||||
wfile.write(h2_conn.data_to_send())
|
||||
wfile.flush()
|
||||
return True
|
||||
|
||||
@pytest.mark.parametrize('announce', [True, False])
|
||||
def test_trailers(self, announce):
|
||||
response_body_buffer = b''
|
||||
h2_conn = self.setup_connection()
|
||||
|
||||
self._send_request(
|
||||
self.client.wfile,
|
||||
h2_conn,
|
||||
stream_id=(1 if announce else 3),
|
||||
headers=[
|
||||
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
|
||||
(':method', 'GET'),
|
||||
(':scheme', 'https'),
|
||||
(':path', '/'),
|
||||
])
|
||||
|
||||
trailers_buffer = None
|
||||
done = False
|
||||
while not done:
|
||||
try:
|
||||
raw = b''.join(http2.read_raw_frame(self.client.rfile))
|
||||
events = h2_conn.receive_data(raw)
|
||||
except exceptions.HttpException:
|
||||
print(traceback.format_exc())
|
||||
assert False
|
||||
|
||||
self.client.wfile.write(h2_conn.data_to_send())
|
||||
self.client.wfile.flush()
|
||||
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
response_body_buffer += event.data
|
||||
elif isinstance(event, h2.events.TrailersReceived):
|
||||
trailers_buffer = event.headers
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
done = True
|
||||
|
||||
h2_conn.close_connection()
|
||||
self.client.wfile.write(h2_conn.data_to_send())
|
||||
self.client.wfile.flush()
|
||||
|
||||
assert len(self.master.state.flows) == 1
|
||||
assert self.master.state.flows[0].response.status_code == 200
|
||||
assert self.master.state.flows[0].response.content == b'response body'
|
||||
assert response_body_buffer == b'response body'
|
||||
assert self.master.state.flows[0].response.trailers['x-my-trailers'] == 'foobar'
|
||||
assert trailers_buffer == [(b'x-my-trailers', b'foobar')]
|
||||
|
Loading…
Reference in New Issue
Block a user