mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 02:10:59 +00:00
unify HTTP trailers APIs
This commit is contained in:
parent
d589f13a1d
commit
ebb061796c
@ -26,6 +26,7 @@ class HTTPRequest(http.Request):
|
||||
http_version,
|
||||
headers,
|
||||
content,
|
||||
trailers=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
is_replay=False,
|
||||
@ -41,6 +42,7 @@ class HTTPRequest(http.Request):
|
||||
http_version,
|
||||
headers,
|
||||
content,
|
||||
trailers,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
)
|
||||
@ -73,6 +75,7 @@ class HTTPRequest(http.Request):
|
||||
http_version=request.data.http_version,
|
||||
headers=request.data.headers,
|
||||
content=request.data.content,
|
||||
trailers=request.data.trailers,
|
||||
timestamp_start=request.data.timestamp_start,
|
||||
timestamp_end=request.data.timestamp_end,
|
||||
)
|
||||
|
@ -174,6 +174,7 @@ def convert_6_7(data):
|
||||
|
||||
def convert_7_8(data):
|
||||
data["version"] = 8
|
||||
data["request"]["trailers"] = None
|
||||
data["response"]["trailers"] = None
|
||||
return data
|
||||
|
||||
|
@ -59,7 +59,7 @@ def read_request_head(rfile):
|
||||
timestamp_start = rfile.first_byte_timestamp
|
||||
|
||||
return request.Request(
|
||||
form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
|
||||
form, method, scheme, host, port, path, http_version, headers, None, None, timestamp_start
|
||||
)
|
||||
|
||||
|
||||
|
@ -134,6 +134,20 @@ class Message(serializable.Serializable):
|
||||
|
||||
content = property(get_content, set_content)
|
||||
|
||||
@property
|
||||
def trailers(self):
|
||||
"""
|
||||
Message trailers object
|
||||
|
||||
Returns:
|
||||
mitmproxy.net.http.Headers
|
||||
"""
|
||||
return self.data.trailers
|
||||
|
||||
@trailers.setter
|
||||
def trailers(self, h):
|
||||
self.data.trailers = h
|
||||
|
||||
@property
|
||||
def http_version(self):
|
||||
"""
|
||||
|
@ -29,6 +29,7 @@ class RequestData(message.MessageData):
|
||||
http_version,
|
||||
headers=(),
|
||||
content=None,
|
||||
trailers=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None
|
||||
):
|
||||
@ -46,6 +47,8 @@ class RequestData(message.MessageData):
|
||||
headers = nheaders.Headers(headers)
|
||||
if isinstance(content, str):
|
||||
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
||||
if trailers is not None and not isinstance(trailers, nheaders.Headers):
|
||||
trailers = nheaders.Headers(trailers)
|
||||
|
||||
self.first_line_format = first_line_format
|
||||
self.method = method
|
||||
@ -56,6 +59,7 @@ class RequestData(message.MessageData):
|
||||
self.http_version = http_version
|
||||
self.headers = headers
|
||||
self.content = content
|
||||
self.trailers = trailers
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
|
@ -34,6 +34,8 @@ class ResponseData(message.MessageData):
|
||||
headers = nheaders.Headers(headers)
|
||||
if isinstance(content, str):
|
||||
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
|
||||
if trailers is not None and not isinstance(trailers, nheaders.Headers):
|
||||
trailers = nheaders.Headers(trailers)
|
||||
|
||||
self.http_version = http_version
|
||||
self.status_code = status_code
|
||||
|
@ -20,6 +20,9 @@ class _HttpTransmissionLayer(base.Layer):
|
||||
def read_request_body(self, request):
|
||||
raise NotImplementedError()
|
||||
|
||||
def read_request_trailers(self, request):
|
||||
raise NotImplementedError()
|
||||
|
||||
def send_request(self, request):
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -30,11 +33,15 @@ class _HttpTransmissionLayer(base.Layer):
|
||||
raise NotImplementedError()
|
||||
yield "this is a generator" # pragma: no cover
|
||||
|
||||
def read_response_trailers(self, request, response):
|
||||
raise NotImplementedError()
|
||||
|
||||
def read_response(self, request):
|
||||
response = self.read_response_headers()
|
||||
response.data.content = b"".join(
|
||||
self.read_response_body(request, response)
|
||||
)
|
||||
response.data.trailers = self.read_response_trailers(request, response)
|
||||
return response
|
||||
|
||||
def send_response(self, response):
|
||||
@ -42,6 +49,7 @@ class _HttpTransmissionLayer(base.Layer):
|
||||
raise exceptions.HttpException("Cannot assemble flow with missing content")
|
||||
self.send_response_headers(response)
|
||||
self.send_response_body(response, [response.data.content])
|
||||
self.send_response_trailers(response)
|
||||
|
||||
def send_response_headers(self, response):
|
||||
raise NotImplementedError()
|
||||
@ -49,6 +57,9 @@ class _HttpTransmissionLayer(base.Layer):
|
||||
def send_response_body(self, response, chunks):
|
||||
raise NotImplementedError()
|
||||
|
||||
def send_response_trailers(self, response, chunks):
|
||||
raise NotImplementedError()
|
||||
|
||||
def check_close_connection(self, f):
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -255,6 +266,7 @@ class HttpLayer(base.Layer):
|
||||
f.request.data.content = b"".join(
|
||||
self.read_request_body(f.request)
|
||||
)
|
||||
f.request.data.trailers = self.read_request_trailers(f.request)
|
||||
f.request.timestamp_end = time.time()
|
||||
self.channel.ask("http_connect", f)
|
||||
|
||||
@ -282,6 +294,9 @@ class HttpLayer(base.Layer):
|
||||
f.request.data.content = None
|
||||
else:
|
||||
f.request.data.content = b"".join(self.read_request_body(request))
|
||||
|
||||
f.request.data.trailers = self.read_request_trailers(f.request)
|
||||
|
||||
request.timestamp_end = time.time()
|
||||
except exceptions.HttpException as e:
|
||||
# We optimistically guess there might be an HTTP client on the
|
||||
@ -348,6 +363,8 @@ class HttpLayer(base.Layer):
|
||||
else:
|
||||
self.send_request_body(f.request, [f.request.data.content])
|
||||
|
||||
self.send_request_trailers(f.request)
|
||||
|
||||
f.response = self.read_response_headers()
|
||||
|
||||
try:
|
||||
@ -406,10 +423,9 @@ class HttpLayer(base.Layer):
|
||||
# we now need to emulate the responseheaders hook.
|
||||
self.channel.ask("responseheaders", f)
|
||||
|
||||
f.response.data.trailers = self.read_response_trailers(f.request, f.response)
|
||||
|
||||
self.log("response", "debug", [repr(f.response)])
|
||||
# not support HTTP/1.1 trailers
|
||||
if f.request.http_version == "HTTP/2.0":
|
||||
f.response.data.trailers = self.read_trailers_headers()
|
||||
self.channel.ask("response", f)
|
||||
|
||||
if not f.response.stream:
|
||||
|
@ -23,6 +23,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
||||
human.parse_size(self.config.options.body_size_limit)
|
||||
)
|
||||
|
||||
def read_request_trailers(self, request):
|
||||
if "Trailer" in request.headers:
|
||||
# TODO: not implemented yet
|
||||
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
|
||||
return None
|
||||
|
||||
def send_request_headers(self, request):
|
||||
headers = http1.assemble_request_head(request)
|
||||
self.server_conn.wfile.write(headers)
|
||||
@ -33,7 +39,13 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
||||
self.server_conn.wfile.write(chunk)
|
||||
self.server_conn.wfile.flush()
|
||||
|
||||
def send_request_trailers(self, request):
|
||||
if "Trailer" in request.headers:
|
||||
# TODO: not implemented yet
|
||||
self.log("HTTP/1 request trailer headers are not implemented yet!", "warn")
|
||||
|
||||
def send_request(self, request):
|
||||
# TODO: this does not yet support request trailers
|
||||
self.server_conn.wfile.write(http1.assemble_request(request))
|
||||
self.server_conn.wfile.flush()
|
||||
|
||||
@ -49,6 +61,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
||||
human.parse_size(self.config.options.body_size_limit)
|
||||
)
|
||||
|
||||
def read_response_trailers(self, request, response):
|
||||
if "Trailer" in response.headers:
|
||||
# TODO: not implemented yet
|
||||
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
|
||||
return None
|
||||
|
||||
def send_response_headers(self, response):
|
||||
raw = http1.assemble_response_head(response)
|
||||
self.client_conn.wfile.write(raw)
|
||||
@ -59,6 +77,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer):
|
||||
self.client_conn.wfile.write(chunk)
|
||||
self.client_conn.wfile.flush()
|
||||
|
||||
def send_response_trailers(self, response):
|
||||
if "Trailer" in response.headers:
|
||||
# TODO: not implemented yet
|
||||
self.log("HTTP/1 trailer headers are not implemented yet!", "warn")
|
||||
return
|
||||
|
||||
def check_close_connection(self, flow):
|
||||
request_close = http1.connection_close(
|
||||
flow.request.http_version,
|
||||
|
@ -235,8 +235,10 @@ class Http2Layer(base.Layer):
|
||||
return True
|
||||
|
||||
def _handle_trailers(self, eid, event, is_server, other_conn):
|
||||
headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
||||
self.streams[eid].update_trailers(headers)
|
||||
trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers])
|
||||
# TODO: support request trailers as well!
|
||||
self.streams[eid].response_trailers = trailers
|
||||
self.streams[eid].response_trailers_arrived.set()
|
||||
return True
|
||||
|
||||
def _handle_remote_settings_changed(self, event, other_conn):
|
||||
@ -417,15 +419,17 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
self.request_data_queue: queue.Queue[bytes] = queue.Queue()
|
||||
self.request_queued_data_length = 0
|
||||
self.request_data_finished = threading.Event()
|
||||
self.request_trailers_arrived = threading.Event()
|
||||
self.request_trailers = None
|
||||
|
||||
self.response_arrived = threading.Event()
|
||||
self.response_data_queue: queue.Queue[bytes] = queue.Queue()
|
||||
self.response_queued_data_length = 0
|
||||
self.response_data_finished = threading.Event()
|
||||
self.response_trailers_arrived = threading.Event()
|
||||
self.response_trailers = None
|
||||
|
||||
self.no_body = False
|
||||
self.has_trailers = False
|
||||
self.trailers_header = None
|
||||
|
||||
self.priority_exclusive: bool
|
||||
self.priority_depends_on: Optional[int] = None
|
||||
@ -437,8 +441,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
self.zombie = time.time()
|
||||
self.request_data_finished.set()
|
||||
self.request_arrived.set()
|
||||
self.request_trailers_arrived.set()
|
||||
self.response_arrived.set()
|
||||
self.response_data_finished.set()
|
||||
self.response_trailers_arrived.set()
|
||||
|
||||
def connect(self): # pragma: no cover
|
||||
raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.")
|
||||
@ -526,6 +532,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
break
|
||||
self.raise_zombie()
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_request_trailers(self, request):
|
||||
if "trailer" in request.headers:
|
||||
self.request_trailers_arrived.wait()
|
||||
self.raise_zombie()
|
||||
return self.request_trailers
|
||||
return None
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_request_headers(self, request):
|
||||
if self.pushed:
|
||||
@ -600,25 +614,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
)
|
||||
|
||||
@detect_zombie_stream
|
||||
def update_trailers(self, headers):
|
||||
self.trailers_header = headers
|
||||
self.has_trailers = True
|
||||
def send_request_trailers(self, request):
|
||||
self._send_trailers(self.server_conn, self.request_trailers)
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_trailers_headers(self):
|
||||
if self.has_trailers and self.trailers_header:
|
||||
with self.connections[self.client_conn].lock:
|
||||
self.connections[self.client_conn].safe_send_headers(
|
||||
self.raise_zombie,
|
||||
self.client_stream_id,
|
||||
self.trailers_header,
|
||||
end_stream = True
|
||||
)
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_request(self, message):
|
||||
self.send_request_headers(message)
|
||||
self.send_request_body(message, [message.content])
|
||||
def send_request(self, request):
|
||||
self.send_request_headers(request)
|
||||
self.send_request_body(request, [request.content])
|
||||
self.send_request_trailers(request)
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_response_headers(self):
|
||||
@ -640,10 +643,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
timestamp_end=self.timestamp_end,
|
||||
)
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_trailers_headers(self):
|
||||
return self.trailers_header
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_response_body(self, request, response):
|
||||
while True:
|
||||
@ -658,6 +657,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
break
|
||||
self.raise_zombie()
|
||||
|
||||
@detect_zombie_stream
|
||||
def read_response_trailers(self, request, response):
|
||||
if "trailer" in response.headers:
|
||||
self.response_trailers_arrived.wait()
|
||||
self.raise_zombie()
|
||||
return self.response_trailers
|
||||
return None
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_response_headers(self, response):
|
||||
headers = response.headers.copy()
|
||||
@ -670,15 +677,28 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
|
||||
)
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_response_body(self, _response, chunks):
|
||||
def send_response_body(self, response, chunks):
|
||||
self.connections[self.client_conn].safe_send_body(
|
||||
self.raise_zombie,
|
||||
self.client_stream_id,
|
||||
chunks,
|
||||
end_stream = not self.has_trailers
|
||||
end_stream=("trailer" not in response.headers)
|
||||
)
|
||||
if self.has_trailers:
|
||||
self.send_trailers_headers()
|
||||
|
||||
@detect_zombie_stream
|
||||
def send_response_trailers(self, _response):
|
||||
self._send_trailers(self.client_conn, self.response_trailers)
|
||||
|
||||
def _send_trailers(self, conn, trailers):
|
||||
if not trailers:
|
||||
return
|
||||
with self.connections[conn].lock:
|
||||
self.connections[conn].safe_send_headers(
|
||||
self.raise_zombie,
|
||||
self.client_stream_id,
|
||||
trailers,
|
||||
end_stream=True
|
||||
)
|
||||
|
||||
def __call__(self): # pragma: no cover
|
||||
raise EnvironmentError('Http2SingleStreamLayer must be run as thread')
|
||||
|
@ -110,8 +110,9 @@ class HTTP2StateProtocol:
|
||||
b"HTTP/2.0",
|
||||
headers,
|
||||
body,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
None,
|
||||
timestamp_start=timestamp_start,
|
||||
timestamp_end=timestamp_end,
|
||||
)
|
||||
request.stream_id = stream_id
|
||||
|
||||
|
@ -21,8 +21,11 @@ class TestRequestData:
|
||||
treq(headers="foobar")
|
||||
with pytest.raises(ValueError):
|
||||
treq(content="foobar")
|
||||
with pytest.raises(ValueError):
|
||||
treq(trailers="foobar")
|
||||
|
||||
assert isinstance(treq(headers=()).headers, Headers)
|
||||
assert isinstance(treq(trailers=()).trailers, Headers)
|
||||
|
||||
|
||||
class TestRequestCore:
|
||||
|
@ -20,8 +20,11 @@ class TestResponseData:
|
||||
tresp(reason="fööbär")
|
||||
with pytest.raises(ValueError):
|
||||
tresp(content="foobar")
|
||||
with pytest.raises(ValueError):
|
||||
tresp(trailers="foobar")
|
||||
|
||||
assert isinstance(tresp(headers=()).headers, Headers)
|
||||
assert isinstance(tresp(trailers=()).trailers, Headers)
|
||||
|
||||
|
||||
class TestResponseCore:
|
||||
|
@ -1034,37 +1034,19 @@ class TestResponseStreaming(_Http2Test):
|
||||
|
||||
|
||||
class TestTrailers(_Http2Test):
|
||||
request_body_buffer = b''
|
||||
|
||||
@classmethod
|
||||
def handle_server_event(cls, event, h2_conn, rfile, wfile):
|
||||
if isinstance(event, h2.events.ConnectionTerminated):
|
||||
return False
|
||||
elif isinstance(event, h2.events.RequestReceived):
|
||||
assert (b'self.client-foo', b'self.client-bar-1') in event.headers
|
||||
assert (b'self.client-foo', b'self.client-bar-2') in event.headers
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
import warnings
|
||||
with warnings.catch_warnings():
|
||||
# Ignore UnicodeWarning:
|
||||
# h2/utilities.py:64: UnicodeWarning: Unicode equal comparison
|
||||
# failed to convert both arguments to Unicode - interpreting
|
||||
# them as being unequal.
|
||||
# elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20:
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
h2_conn.send_headers(event.stream_id, [
|
||||
(':status', '200'),
|
||||
('server-foo', 'server-bar'),
|
||||
('föo', 'bär'),
|
||||
('X-Stream-ID', str(event.stream_id)),
|
||||
])
|
||||
h2_conn.send_headers(event.stream_id, [
|
||||
(':status', '200'),
|
||||
('trailer', 'x-my-trailers')
|
||||
])
|
||||
h2_conn.send_data(event.stream_id, b'response body')
|
||||
h2_conn.send_headers(event.stream_id, [('trailers', 'trailers-foo')], end_stream=True)
|
||||
h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True)
|
||||
wfile.write(h2_conn.data_to_send())
|
||||
wfile.flush()
|
||||
elif isinstance(event, h2.events.DataReceived):
|
||||
cls.request_body_buffer += event.data
|
||||
return True
|
||||
|
||||
def test_trailers(self):
|
||||
@ -1079,11 +1061,9 @@ class TestTrailers(_Http2Test):
|
||||
(':method', 'GET'),
|
||||
(':scheme', 'https'),
|
||||
(':path', '/'),
|
||||
('self.client-FoO', 'self.client-bar-1'),
|
||||
('self.client-FoO', 'self.client-bar-2'),
|
||||
],
|
||||
body=b'request body')
|
||||
])
|
||||
|
||||
trailers_buffer = None
|
||||
done = False
|
||||
while not done:
|
||||
try:
|
||||
@ -1099,6 +1079,8 @@ class TestTrailers(_Http2Test):
|
||||
for event in events:
|
||||
if isinstance(event, h2.events.DataReceived):
|
||||
response_body_buffer += event.data
|
||||
elif isinstance(event, h2.events.TrailersReceived):
|
||||
trailers_buffer = event.headers
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
done = True
|
||||
|
||||
@ -1108,9 +1090,7 @@ class TestTrailers(_Http2Test):
|
||||
|
||||
assert len(self.master.state.flows) == 1
|
||||
assert self.master.state.flows[0].response.status_code == 200
|
||||
assert self.master.state.flows[0].response.headers['server-foo'] == 'server-bar'
|
||||
assert self.master.state.flows[0].response.headers['föo'] == 'bär'
|
||||
assert self.master.state.flows[0].response.content == b'response body'
|
||||
assert self.request_body_buffer == b'request body'
|
||||
assert response_body_buffer == b'response body'
|
||||
assert self.master.state.flows[0].response.data.trailers['trailers'] == 'trailers-foo'
|
||||
assert self.master.state.flows[0].response.data.trailers['x-my-trailers'] == 'foobar'
|
||||
assert trailers_buffer == [(b'x-my-trailers', b'foobar')]
|
||||
|
Loading…
Reference in New Issue
Block a user