diff --git a/mitmproxy/http.py b/mitmproxy/http.py index bf38a78e8..2cd5bccc9 100644 --- a/mitmproxy/http.py +++ b/mitmproxy/http.py @@ -26,6 +26,7 @@ class HTTPRequest(http.Request): http_version, headers, content, + trailers=None, timestamp_start=None, timestamp_end=None, is_replay=False, @@ -41,6 +42,7 @@ class HTTPRequest(http.Request): http_version, headers, content, + trailers, timestamp_start, timestamp_end, ) @@ -73,6 +75,7 @@ class HTTPRequest(http.Request): http_version=request.data.http_version, headers=request.data.headers, content=request.data.content, + trailers=request.data.trailers, timestamp_start=request.data.timestamp_start, timestamp_end=request.data.timestamp_end, ) diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py index e9d74e1c6..16e157756 100644 --- a/mitmproxy/io/compat.py +++ b/mitmproxy/io/compat.py @@ -174,6 +174,7 @@ def convert_6_7(data): def convert_7_8(data): data["version"] = 8 + data["request"]["trailers"] = None data["response"]["trailers"] = None return data diff --git a/mitmproxy/net/http/http1/read.py b/mitmproxy/net/http/http1/read.py index 0f60c8f4b..ce2007ed9 100644 --- a/mitmproxy/net/http/http1/read.py +++ b/mitmproxy/net/http/http1/read.py @@ -59,7 +59,7 @@ def read_request_head(rfile): timestamp_start = rfile.first_byte_timestamp return request.Request( - form, method, scheme, host, port, path, http_version, headers, None, timestamp_start + form, method, scheme, host, port, path, http_version, headers, None, None, timestamp_start ) diff --git a/mitmproxy/net/http/message.py b/mitmproxy/net/http/message.py index 478d334e4..4a16f52aa 100644 --- a/mitmproxy/net/http/message.py +++ b/mitmproxy/net/http/message.py @@ -134,6 +134,20 @@ class Message(serializable.Serializable): content = property(get_content, set_content) + @property + def trailers(self): + """ + Message trailers object + + Returns: + mitmproxy.net.http.Headers + """ + return self.data.trailers + + @trailers.setter + def trailers(self, h): + self.data.trailers = h + @property def http_version(self): """ diff --git a/mitmproxy/net/http/request.py b/mitmproxy/net/http/request.py index ba699e2aa..243378cf4 100644 --- a/mitmproxy/net/http/request.py +++ b/mitmproxy/net/http/request.py @@ -29,6 +29,7 @@ class RequestData(message.MessageData): http_version, headers=(), content=None, + trailers=None, timestamp_start=None, timestamp_end=None ): @@ -46,6 +47,8 @@ class RequestData(message.MessageData): headers = nheaders.Headers(headers) if isinstance(content, str): raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) + if trailers is not None and not isinstance(trailers, nheaders.Headers): + trailers = nheaders.Headers(trailers) self.first_line_format = first_line_format self.method = method @@ -56,6 +59,7 @@ class RequestData(message.MessageData): self.http_version = http_version self.headers = headers self.content = content + self.trailers = trailers self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end diff --git a/mitmproxy/net/http/response.py b/mitmproxy/net/http/response.py index edd8d4a62..7cc41940f 100644 --- a/mitmproxy/net/http/response.py +++ b/mitmproxy/net/http/response.py @@ -34,6 +34,8 @@ class ResponseData(message.MessageData): headers = nheaders.Headers(headers) if isinstance(content, str): raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) + if trailers is not None and not isinstance(trailers, nheaders.Headers): + trailers = nheaders.Headers(trailers) self.http_version = http_version self.status_code = status_code diff --git a/mitmproxy/proxy/protocol/http.py b/mitmproxy/proxy/protocol/http.py index 7949e7e21..c2f3779df 100644 --- a/mitmproxy/proxy/protocol/http.py +++ b/mitmproxy/proxy/protocol/http.py @@ -20,6 +20,9 @@ class _HttpTransmissionLayer(base.Layer): def read_request_body(self, request): raise NotImplementedError() + def read_request_trailers(self, request): + raise NotImplementedError() + def send_request(self, request): raise NotImplementedError() @@ -30,11 +33,15 @@ class _HttpTransmissionLayer(base.Layer): raise NotImplementedError() yield "this is a generator" # pragma: no cover + def read_response_trailers(self, request, response): + raise NotImplementedError() + def read_response(self, request): response = self.read_response_headers() response.data.content = b"".join( self.read_response_body(request, response) ) + response.data.trailers = self.read_response_trailers(request, response) return response def send_response(self, response): @@ -42,6 +49,7 @@ class _HttpTransmissionLayer(base.Layer): raise exceptions.HttpException("Cannot assemble flow with missing content") self.send_response_headers(response) self.send_response_body(response, [response.data.content]) + self.send_response_trailers(response) def send_response_headers(self, response): raise NotImplementedError() @@ -49,6 +57,9 @@ class _HttpTransmissionLayer(base.Layer): def send_response_body(self, response, chunks): raise NotImplementedError() + def send_response_trailers(self, response, chunks): + raise NotImplementedError() + def check_close_connection(self, f): raise NotImplementedError() @@ -255,6 +266,7 @@ class HttpLayer(base.Layer): f.request.data.content = b"".join( self.read_request_body(f.request) ) + f.request.data.trailers = self.read_request_trailers(f.request) f.request.timestamp_end = time.time() self.channel.ask("http_connect", f) @@ -282,6 +294,9 @@ class HttpLayer(base.Layer): f.request.data.content = None else: f.request.data.content = b"".join(self.read_request_body(request)) + + f.request.data.trailers = self.read_request_trailers(f.request) + request.timestamp_end = time.time() except exceptions.HttpException as e: # We optimistically guess there might be an HTTP client on the @@ -348,6 +363,8 @@ class HttpLayer(base.Layer): else: self.send_request_body(f.request, [f.request.data.content]) + self.send_request_trailers(f.request) + f.response = self.read_response_headers() try: @@ -406,10 +423,9 @@ class HttpLayer(base.Layer): # we now need to emulate the responseheaders hook. self.channel.ask("responseheaders", f) + f.response.data.trailers = self.read_response_trailers(f.request, f.response) + self.log("response", "debug", [repr(f.response)]) - # not support HTTP/1.1 trailers - if f.request.http_version == "HTTP/2.0": - f.response.data.trailers = self.read_trailers_headers() self.channel.ask("response", f) if not f.response.stream: diff --git a/mitmproxy/proxy/protocol/http1.py b/mitmproxy/proxy/protocol/http1.py index 91f1e9b7d..5fc4efbaf 100644 --- a/mitmproxy/proxy/protocol/http1.py +++ b/mitmproxy/proxy/protocol/http1.py @@ -23,6 +23,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer): human.parse_size(self.config.options.body_size_limit) ) + def read_request_trailers(self, request): + if "Trailer" in request.headers: + # TODO: not implemented yet + self.log("HTTP/1 request trailer headers are not implemented yet!", "warn") + return None + def send_request_headers(self, request): headers = http1.assemble_request_head(request) self.server_conn.wfile.write(headers) @@ -33,7 +39,13 @@ class Http1Layer(httpbase._HttpTransmissionLayer): self.server_conn.wfile.write(chunk) self.server_conn.wfile.flush() + def send_request_trailers(self, request): + if "Trailer" in request.headers: + # TODO: not implemented yet + self.log("HTTP/1 request trailer headers are not implemented yet!", "warn") + def send_request(self, request): + # TODO: this does not yet support request trailers self.server_conn.wfile.write(http1.assemble_request(request)) self.server_conn.wfile.flush() @@ -49,6 +61,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer): human.parse_size(self.config.options.body_size_limit) ) + def read_response_trailers(self, request, response): + if "Trailer" in response.headers: + # TODO: not implemented yet + self.log("HTTP/1 trailer headers are not implemented yet!", "warn") + return None + def send_response_headers(self, response): raw = http1.assemble_response_head(response) self.client_conn.wfile.write(raw) @@ -59,6 +77,12 @@ class Http1Layer(httpbase._HttpTransmissionLayer): self.client_conn.wfile.write(chunk) self.client_conn.wfile.flush() + def send_response_trailers(self, response): + if "Trailer" in response.headers: + # TODO: not implemented yet + self.log("HTTP/1 trailer headers are not implemented yet!", "warn") + return + def check_close_connection(self, flow): request_close = http1.connection_close( flow.request.http_version, diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index f5ab09f09..9275e6bdb 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -235,8 +235,10 @@ class Http2Layer(base.Layer): return True def _handle_trailers(self, eid, event, is_server, other_conn): - headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers]) - self.streams[eid].update_trailers(headers) + trailers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers]) + # TODO: support request trailers as well! + self.streams[eid].response_trailers = trailers + self.streams[eid].response_trailers_arrived.set() return True def _handle_remote_settings_changed(self, event, other_conn): @@ -417,15 +419,17 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr self.request_data_queue: queue.Queue[bytes] = queue.Queue() self.request_queued_data_length = 0 self.request_data_finished = threading.Event() + self.request_trailers_arrived = threading.Event() + self.request_trailers = None self.response_arrived = threading.Event() self.response_data_queue: queue.Queue[bytes] = queue.Queue() self.response_queued_data_length = 0 self.response_data_finished = threading.Event() + self.response_trailers_arrived = threading.Event() + self.response_trailers = None self.no_body = False - self.has_trailers = False - self.trailers_header = None self.priority_exclusive: bool self.priority_depends_on: Optional[int] = None @@ -437,8 +441,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr self.zombie = time.time() self.request_data_finished.set() self.request_arrived.set() + self.request_trailers_arrived.set() self.response_arrived.set() self.response_data_finished.set() + self.response_trailers_arrived.set() def connect(self): # pragma: no cover raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.") @@ -526,6 +532,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr break self.raise_zombie() + @detect_zombie_stream + def read_request_trailers(self, request): + if "trailer" in request.headers: + self.request_trailers_arrived.wait() + self.raise_zombie() + return self.request_trailers + return None + @detect_zombie_stream def send_request_headers(self, request): if self.pushed: @@ -600,25 +614,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr ) @detect_zombie_stream - def update_trailers(self, headers): - self.trailers_header = headers - self.has_trailers = True + def send_request_trailers(self, request): + self._send_trailers(self.server_conn, self.request_trailers) @detect_zombie_stream - def send_trailers_headers(self): - if self.has_trailers and self.trailers_header: - with self.connections[self.client_conn].lock: - self.connections[self.client_conn].safe_send_headers( - self.raise_zombie, - self.client_stream_id, - self.trailers_header, - end_stream = True - ) - - @detect_zombie_stream - def send_request(self, message): - self.send_request_headers(message) - self.send_request_body(message, [message.content]) + def send_request(self, request): + self.send_request_headers(request) + self.send_request_body(request, [request.content]) + self.send_request_trailers(request) @detect_zombie_stream def read_response_headers(self): @@ -640,10 +643,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr timestamp_end=self.timestamp_end, ) - @detect_zombie_stream - def read_trailers_headers(self): - return self.trailers_header - @detect_zombie_stream def read_response_body(self, request, response): while True: @@ -658,6 +657,14 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr break self.raise_zombie() + @detect_zombie_stream + def read_response_trailers(self, request, response): + if "trailer" in response.headers: + self.response_trailers_arrived.wait() + self.raise_zombie() + return self.response_trailers + return None + @detect_zombie_stream def send_response_headers(self, response): headers = response.headers.copy() @@ -670,15 +677,28 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr ) @detect_zombie_stream - def send_response_body(self, _response, chunks): + def send_response_body(self, response, chunks): self.connections[self.client_conn].safe_send_body( self.raise_zombie, self.client_stream_id, chunks, - end_stream = not self.has_trailers + end_stream=("trailer" not in response.headers) ) - if self.has_trailers: - self.send_trailers_headers() + + @detect_zombie_stream + def send_response_trailers(self, _response): + self._send_trailers(self.client_conn, self.response_trailers) + + def _send_trailers(self, conn, trailers): + if not trailers: + return + with self.connections[conn].lock: + self.connections[conn].safe_send_headers( + self.raise_zombie, + self.client_stream_id, + trailers, + end_stream=True + ) def __call__(self): # pragma: no cover raise EnvironmentError('Http2SingleStreamLayer must be run as thread') diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index c56d304db..748893ee2 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -110,8 +110,9 @@ class HTTP2StateProtocol: b"HTTP/2.0", headers, body, - timestamp_start, - timestamp_end, + None, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, ) request.stream_id = stream_id diff --git a/test/mitmproxy/net/http/test_request.py b/test/mitmproxy/net/http/test_request.py index 71d5c7a12..30129d331 100644 --- a/test/mitmproxy/net/http/test_request.py +++ b/test/mitmproxy/net/http/test_request.py @@ -21,8 +21,11 @@ class TestRequestData: treq(headers="foobar") with pytest.raises(ValueError): treq(content="foobar") + with pytest.raises(ValueError): + treq(trailers="foobar") assert isinstance(treq(headers=()).headers, Headers) + assert isinstance(treq(trailers=()).trailers, Headers) class TestRequestCore: diff --git a/test/mitmproxy/net/http/test_response.py b/test/mitmproxy/net/http/test_response.py index 08d72840e..7eb3eab82 100644 --- a/test/mitmproxy/net/http/test_response.py +++ b/test/mitmproxy/net/http/test_response.py @@ -20,8 +20,11 @@ class TestResponseData: tresp(reason="fööbär") with pytest.raises(ValueError): tresp(content="foobar") + with pytest.raises(ValueError): + tresp(trailers="foobar") assert isinstance(tresp(headers=()).headers, Headers) + assert isinstance(tresp(trailers=()).trailers, Headers) class TestResponseCore: diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index d6870de45..1529e7317 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -1034,37 +1034,19 @@ class TestResponseStreaming(_Http2Test): class TestTrailers(_Http2Test): - request_body_buffer = b'' - @classmethod def handle_server_event(cls, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False - elif isinstance(event, h2.events.RequestReceived): - assert (b'self.client-foo', b'self.client-bar-1') in event.headers - assert (b'self.client-foo', b'self.client-bar-2') in event.headers elif isinstance(event, h2.events.StreamEnded): - import warnings - with warnings.catch_warnings(): - # Ignore UnicodeWarning: - # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison - # failed to convert both arguments to Unicode - interpreting - # them as being unequal. - # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: - - warnings.simplefilter("ignore") - h2_conn.send_headers(event.stream_id, [ - (':status', '200'), - ('server-foo', 'server-bar'), - ('föo', 'bär'), - ('X-Stream-ID', str(event.stream_id)), - ]) + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('trailer', 'x-my-trailers') + ]) h2_conn.send_data(event.stream_id, b'response body') - h2_conn.send_headers(event.stream_id, [('trailers', 'trailers-foo')], end_stream=True) + h2_conn.send_headers(event.stream_id, [('x-my-trailers', 'foobar')], end_stream=True) wfile.write(h2_conn.data_to_send()) wfile.flush() - elif isinstance(event, h2.events.DataReceived): - cls.request_body_buffer += event.data return True def test_trailers(self): @@ -1079,11 +1061,9 @@ class TestTrailers(_Http2Test): (':method', 'GET'), (':scheme', 'https'), (':path', '/'), - ('self.client-FoO', 'self.client-bar-1'), - ('self.client-FoO', 'self.client-bar-2'), - ], - body=b'request body') + ]) + trailers_buffer = None done = False while not done: try: @@ -1099,6 +1079,8 @@ class TestTrailers(_Http2Test): for event in events: if isinstance(event, h2.events.DataReceived): response_body_buffer += event.data + elif isinstance(event, h2.events.TrailersReceived): + trailers_buffer = event.headers elif isinstance(event, h2.events.StreamEnded): done = True @@ -1108,9 +1090,7 @@ class TestTrailers(_Http2Test): assert len(self.master.state.flows) == 1 assert self.master.state.flows[0].response.status_code == 200 - assert self.master.state.flows[0].response.headers['server-foo'] == 'server-bar' - assert self.master.state.flows[0].response.headers['föo'] == 'bär' assert self.master.state.flows[0].response.content == b'response body' - assert self.request_body_buffer == b'request body' assert response_body_buffer == b'response body' - assert self.master.state.flows[0].response.data.trailers['trailers'] == 'trailers-foo' + assert self.master.state.flows[0].response.data.trailers['x-my-trailers'] == 'foobar' + assert trailers_buffer == [(b'x-my-trailers', b'foobar')]