From 8ae2ab2aca3215d0d07bd4d16673f6c8be4a024d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 9 Dec 2020 17:45:01 +0100 Subject: [PATCH] [sans-io] fix HTTP/2 stream cancellation --- mitmproxy/proxy2/layers/http/__init__.py | 13 ++++--- mitmproxy/proxy2/layers/http/_http1.py | 3 +- .../proxy2/layers/http/test_http2.py | 35 +++++++++++++++++-- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/mitmproxy/proxy2/layers/http/__init__.py b/mitmproxy/proxy2/layers/http/__init__.py index 5d135a7a8..4005f2ef1 100644 --- a/mitmproxy/proxy2/layers/http/__init__.py +++ b/mitmproxy/proxy2/layers/http/__init__.py @@ -93,7 +93,7 @@ class HttpStream(layer.Layer): request_body_buf: bytes response_body_buf: bytes flow: http.HTTPFlow - stream_id: StreamId + stream_id: StreamId = None child_layer: typing.Optional[layer.Layer] = None @property @@ -341,13 +341,16 @@ class HttpStream(layer.Layer): isinstance(event, RequestProtocolError) and self.client_state in (self.state_stream_request_body, self.state_done) ) + response_hook_already_triggered = ( + self.client_state == self.state_errored + or + self.server_state in (self.state_done, self.state_errored) + ) + if is_client_error_but_we_already_talk_upstream: yield SendHttp(event, self.context.server) self.client_state = self.state_errored - response_hook_already_triggered = ( - self.server_state in (self.state_done, self.state_errored) - ) if not response_hook_already_triggered: # We don't want to trigger both a response hook and an error hook, # so we need to check if the response is done yet or not. @@ -357,7 +360,7 @@ class HttpStream(layer.Layer): if (yield from self.check_killed(False)): return - if isinstance(event, ResponseProtocolError): + if isinstance(event, ResponseProtocolError) and self.client_state != self.state_errored: yield SendHttp(event, self.context.client) self.server_state = self.state_errored diff --git a/mitmproxy/proxy2/layers/http/_http1.py b/mitmproxy/proxy2/layers/http/_http1.py index 7ba89f8a5..865c5f3d9 100644 --- a/mitmproxy/proxy2/layers/http/_http1.py +++ b/mitmproxy/proxy2/layers/http/_http1.py @@ -317,7 +317,8 @@ class Http1Client(Http1Connection): else: pass # FIXME: protect against header size DoS elif isinstance(event, events.ConnectionClosed): - yield commands.CloseConnection(self.conn) + if self.conn.state & ConnectionState.CAN_WRITE: + yield commands.CloseConnection(self.conn) if self.stream_id: if self.buf: yield ReceiveHttp(ResponseProtocolError(self.stream_id, diff --git a/test/mitmproxy/proxy2/layers/http/test_http2.py b/test/mitmproxy/proxy2/layers/http/test_http2.py index 7bf31b82d..09637338a 100644 --- a/test/mitmproxy/proxy2/layers/http/test_http2.py +++ b/test/mitmproxy/proxy2/layers/http/test_http2.py @@ -304,6 +304,38 @@ def test_rst_then_close(tctx): assert flow().error.msg == "connection cancelled" +def test_cancel_then_server_disconnect(tctx): + """ + Test that we properly handle the case of the following event sequence: + - client cancels a stream + - we start an error hook + - server disconnects + - error hook completes. + """ + playbook, cff = start_h2_client(tctx) + flow = Placeholder(HTTPFlow) + server = Placeholder(Server) + + assert ( + playbook + >> DataReceived(tctx.client, + cff.build_headers_frame(example_request_headers, flags=["END_STREAM"]).serialize()) + << http.HttpRequestHeadersHook(flow) + >> reply() + << http.HttpRequestHook(flow) + >> reply() + << OpenConnection(server) + >> reply(None) + << SendData(server, b'GET / HTTP/1.1\r\nHost: example.com\r\n\r\n') + >> DataReceived(tctx.client, cff.build_rst_stream_frame(1, ErrorCodes.CANCEL).serialize()) + << CloseConnection(server) + << http.HttpErrorHook(flow) + >> reply() + >> ConnectionClosed(server) + << None + ) + + def test_stream_concurrency(tctx): """Test that we can send an intercepted request with a lower stream id than one that has already been sent.""" playbook, cff = start_h2_client(tctx) @@ -375,9 +407,6 @@ def test_max_concurrency(tctx): flags=["END_STREAM"], stream_id=3).serialize()) # Can't send it upstream yet, all streams in use! - ) - assert ( - playbook >> DataReceived(server, sff.build_headers_frame(example_response_headers, flags=["END_STREAM"], stream_id=1).serialize())