diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 985544ff1..d0276160d 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -17,10 +17,10 @@ from mitmproxy.proxy.utils import expect from mitmproxy.utils import human from mitmproxy.websocket import WebSocketData from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId -from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, ResponseTrailers, \ - ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError +from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, RequestTrailers, \ + ResponseTrailers, ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \ - HttpResponseHook, HttpResponseTrailersHook + HttpResponseHook, HttpRequestTrailersHook, HttpResponseTrailersHook from ._http1 import Http1Client, Http1Server from ._http2 import Http2Client, Http2Server from ...context import Context @@ -129,7 +129,7 @@ class HttpStream(layer.Layer): self.client_state = self.state_wait_for_request_headers elif isinstance(event, (RequestProtocolError, ResponseProtocolError)): yield from self.handle_protocol_error(event) - elif isinstance(event, (RequestHeaders, RequestData, RequestEndOfMessage)): + elif isinstance(event, (RequestHeaders, RequestData, RequestTrailers, RequestEndOfMessage)): yield from self.client_state(event) else: yield from self.server_state(event) @@ -221,11 +221,15 @@ class HttpStream(layer.Layer): ) self.client_state = self.state_stream_request_body - @expect(RequestData, RequestEndOfMessage) + @expect(RequestData, RequestTrailers, RequestEndOfMessage) def state_stream_request_body(self, event: Union[RequestData, RequestEndOfMessage]) -> layer.CommandGenerator[None]: if isinstance(event, RequestData): if callable(self.flow.request.stream): event.data = self.flow.request.stream(event.data) + elif isinstance(event, RequestTrailers): + assert self.flow.request + self.flow.request.trailers = event.trailers + yield HttpRequestTrailersHook(self.flow) elif isinstance(event, RequestEndOfMessage): self.flow.request.timestamp_end = time.time() self.client_state = self.state_done @@ -240,12 +244,18 @@ class HttpStream(layer.Layer): for evt in self._paused_event_queue: if isinstance(evt, ResponseProtocolError): return + if self.flow.request.trailers: + yield SendHttp(RequestTrailers(self.stream_id, self.flow.request.trailers), self.context.server) yield SendHttp(event, self.context.server) - @expect(RequestData, RequestEndOfMessage) + @expect(RequestData, RequestTrailers, RequestEndOfMessage) def state_consume_request_body(self, event: events.Event) -> layer.CommandGenerator[None]: if isinstance(event, RequestData): self.request_body_buf += event.data + elif isinstance(event, RequestTrailers): + assert self.flow.request + self.flow.request.trailers = event.trailers + yield HttpRequestTrailersHook(self.flow) elif isinstance(event, RequestEndOfMessage): self.flow.request.timestamp_end = time.time() self.flow.request.data.content = self.request_body_buf @@ -271,6 +281,8 @@ class HttpStream(layer.Layer): yield SendHttp(RequestHeaders(self.stream_id, self.flow.request, not content), self.context.server) if content: yield SendHttp(RequestData(self.stream_id, content), self.context.server) + if self.flow.request.trailers: + yield SendHttp(RequestTrailers(self.stream_id, self.flow.request.trailers), self.context.server) yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server) @expect(ResponseHeaders) @@ -285,7 +297,7 @@ class HttpStream(layer.Layer): else: self.server_state = self.state_consume_response_body - @expect(ResponseData, ResponseEndOfMessage) + @expect(ResponseData, ResponseTrailers, ResponseEndOfMessage) def state_stream_response_body(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.flow.response if isinstance(event, ResponseData): @@ -294,6 +306,10 @@ class HttpStream(layer.Layer): else: data = event.data yield SendHttp(ResponseData(self.stream_id, data), self.context.client) + elif isinstance(event, ResponseTrailers): + assert self.flow.response + self.flow.response.trailers = event.trailers + yield HttpResponseTrailersHook(self.flow) elif isinstance(event, ResponseEndOfMessage): yield from self.send_response(already_streamed=True) @@ -341,7 +357,7 @@ class HttpStream(layer.Layer): if content: yield SendHttp(ResponseData(self.stream_id, content), self.context.client) if self.flow.response.trailers: - yield SendHttp(ResponseTrailers(self.stream_id, self.flow.response.trailers, end_stream=True), self.context.client) + yield SendHttp(ResponseTrailers(self.stream_id, self.flow.response.trailers), self.context.client) yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client) diff --git a/mitmproxy/proxy/layers/http/_events.py b/mitmproxy/proxy/layers/http/_events.py index 71c6d6666..398c71905 100644 --- a/mitmproxy/proxy/layers/http/_events.py +++ b/mitmproxy/proxy/layers/http/_events.py @@ -45,10 +45,14 @@ class ResponseData(HttpEvent): self.data = data +@dataclass +class RequestTrailers(HttpEvent): + trailers: http.Headers + + @dataclass class ResponseTrailers(HttpEvent): trailers: http.Headers - end_stream: bool = True @dataclass @@ -92,6 +96,7 @@ __all__ = [ "RequestEndOfMessage", "ResponseHeaders", "ResponseData", + "RequestTrailers", "ResponseTrailers", "ResponseEndOfMessage", "RequestProtocolError", diff --git a/mitmproxy/proxy/layers/http/_hooks.py b/mitmproxy/proxy/layers/http/_hooks.py index f353a254e..268b6be60 100644 --- a/mitmproxy/proxy/layers/http/_hooks.py +++ b/mitmproxy/proxy/layers/http/_hooks.py @@ -45,10 +45,29 @@ class HttpResponseHook(commands.StartHook): flow: http.HTTPFlow +@dataclass +class HttpRequestTrailersHook(commands.StartHook): + """ + The HTTP request trailers has been read. + HTTP trailers are a rarely-used feature in the HTTP specification + which allows peers to send additional headers after the message body. + This is useful for metadata that is dynamically generated while + the message body is sent, for example a digital signature + or post-processing status. + """ + name = "requesttrailers" + flow: http.HTTPFlow + + @dataclass class HttpResponseTrailersHook(commands.StartHook): """ The HTTP response trailers has been read. + HTTP trailers are a rarely-used feature in the HTTP specification + which allows peers to send additional headers after the message body. + This is useful for metadata that is dynamically generated while + the message body is sent, for example a digital signature + or post-processing status. """ name = "responsetrailers" flow: http.HTTPFlow diff --git a/mitmproxy/proxy/layers/http/_http2.py b/mitmproxy/proxy/layers/http/_http2.py index 0d1923eb8..9c24f3c6c 100644 --- a/mitmproxy/proxy/layers/http/_http2.py +++ b/mitmproxy/proxy/layers/http/_http2.py @@ -17,7 +17,7 @@ from mitmproxy.connection import Connection from mitmproxy.net.http import status_codes, url from mitmproxy.utils import human from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ - ResponseEndOfMessage, ResponseHeaders, ResponseTrailers, ResponseProtocolError + ResponseEndOfMessage, ResponseHeaders, RequestTrailers, ResponseTrailers, ResponseProtocolError from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error from ._http_h2 import BufferedH2Connection, H2ConnectionLogger from ...commands import CloseConnection, Log, SendData @@ -97,10 +97,10 @@ class Http2Connection(HttpConnection): assert isinstance(event, (RequestData, ResponseData)) if self.is_open_for_us(event.stream_id): self.h2_conn.send_data(event.stream_id, event.data) - elif isinstance(event, ResponseTrailers): + elif isinstance(event, (RequestTrailers, ResponseTrailers)): if self.is_open_for_us(event.stream_id): trailers = [*event.trailers.fields] - self.h2_conn.send_headers(event.stream_id, trailers, event.end_stream) + self.h2_conn.send_headers(event.stream_id, trailers, end_stream=True) elif isinstance(event, self.SendEndOfMessage): if self.is_open_for_us(event.stream_id): self.h2_conn.end_stream(event.stream_id) @@ -220,7 +220,7 @@ class Http2Connection(HttpConnection): elif isinstance(event, h2.events.PingAckReceived): pass elif isinstance(event, h2.events.TrailersReceived): - yield Log("Received HTTP/2 request trailers, which are currently unimplemented and silently discarded", "error") + pass elif isinstance(event, h2.events.PushedStreamReceived): yield Log("Received HTTP/2 push promise, even though we signalled no support.", "error") elif isinstance(event, h2.events.UnknownFrameReceived): @@ -326,6 +326,10 @@ class Http2Server(Http2Connection): self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED yield ReceiveHttp(RequestHeaders(event.stream_id, request, end_stream=bool(event.stream_ended))) return False + elif isinstance(event, h2.events.TrailersReceived): + trailers = http.Headers(event.headers) + yield ReceiveHttp(RequestTrailers(event.stream_id, trailers)) + return False else: return (yield from super().handle_h2_event(event)) @@ -453,8 +457,8 @@ class Http2Client(Http2Connection): yield ReceiveHttp(ResponseHeaders(event.stream_id, response, bool(event.stream_ended))) return False elif isinstance(event, h2.events.TrailersReceived): - pseudo_trailers, trailers = split_pseudo_headers(event.headers) - yield ReceiveHttp(ResponseTrailers(event.stream_id, trailers, bool(event.stream_ended))) + trailers = http.Headers(event.headers) + yield ReceiveHttp(ResponseTrailers(event.stream_id, trailers)) return False elif isinstance(event, h2.events.RequestReceived): yield from self.protocol_error(f"HTTP/2 protocol error: received request from server") diff --git a/test/mitmproxy/proxy/layers/http/test_http2.py b/test/mitmproxy/proxy/layers/http/test_http2.py index 94c3c4808..45f181163 100644 --- a/test/mitmproxy/proxy/layers/http/test_http2.py +++ b/test/mitmproxy/proxy/layers/http/test_http2.py @@ -30,9 +30,14 @@ example_response_headers = ( (b':status', b'200'), ) +example_request_trailers = ( + (b'req-trailer-a', b'a'), + (b'req-trailer-b', b'b') +) + example_response_trailers = ( - (b'my-trailer-b', b'0'), - (b'my-trailer-b', b'0') + (b'resp-trailer-a', b'a'), + (b'resp-trailer-b', b'b') ) @@ -108,7 +113,7 @@ def test_simple(tctx): assert flow().response.text == "Hello, World!" -def test_trailers(tctx): +def test_response_trailers(tctx): playbook, cff = start_h2_client(tctx) flow = Placeholder(HTTPFlow) server = Placeholder(Server) @@ -152,6 +157,37 @@ def test_trailers(tctx): assert flow().response.text == "Hello, World!" +def test_request_trailers(tctx): + playbook, cff = start_h2_client(tctx) + flow = Placeholder(HTTPFlow) + server = Placeholder(Server) + initial = Placeholder(bytes) + assert ( + playbook + >> DataReceived(tctx.client, + cff.build_headers_frame(example_request_headers).serialize()) + << http.HttpRequestHeadersHook(flow) + >> reply() + >> DataReceived(tctx.client, cff.build_data_frame(b"Hello, World!").serialize()) + >> DataReceived(tctx.client, + cff.build_headers_frame(example_request_trailers, flags=["END_STREAM"]).serialize()) + << http.HttpRequestTrailersHook(flow) + >> reply() + << http.HttpRequestHook(flow) + >> reply() + << OpenConnection(server) + >> reply(None, side_effect=make_h2) + << SendData(server, initial) + ) + frames = decode_frames(initial()) + assert [type(x) for x in frames] == [ + hyperframe.frame.SettingsFrame, + hyperframe.frame.HeadersFrame, + hyperframe.frame.DataFrame, + hyperframe.frame.HeadersFrame, + ] + + def test_upstream_error(tctx): playbook, cff = start_h2_client(tctx) flow = Placeholder(HTTPFlow)