[sans-io] fix HTTP/2 stream cancellation

This commit is contained in:
Maximilian Hils 2020-12-09 17:45:01 +01:00
parent d32a5d5f33
commit 8ae2ab2aca
3 changed files with 42 additions and 9 deletions

View File

@ -93,7 +93,7 @@ class HttpStream(layer.Layer):
request_body_buf: bytes request_body_buf: bytes
response_body_buf: bytes response_body_buf: bytes
flow: http.HTTPFlow flow: http.HTTPFlow
stream_id: StreamId stream_id: StreamId = None
child_layer: typing.Optional[layer.Layer] = None child_layer: typing.Optional[layer.Layer] = None
@property @property
@ -341,13 +341,16 @@ class HttpStream(layer.Layer):
isinstance(event, RequestProtocolError) and isinstance(event, RequestProtocolError) and
self.client_state in (self.state_stream_request_body, self.state_done) self.client_state in (self.state_stream_request_body, self.state_done)
) )
response_hook_already_triggered = (
self.client_state == self.state_errored
or
self.server_state in (self.state_done, self.state_errored)
)
if is_client_error_but_we_already_talk_upstream: if is_client_error_but_we_already_talk_upstream:
yield SendHttp(event, self.context.server) yield SendHttp(event, self.context.server)
self.client_state = self.state_errored self.client_state = self.state_errored
response_hook_already_triggered = (
self.server_state in (self.state_done, self.state_errored)
)
if not response_hook_already_triggered: if not response_hook_already_triggered:
# We don't want to trigger both a response hook and an error hook, # We don't want to trigger both a response hook and an error hook,
# so we need to check if the response is done yet or not. # so we need to check if the response is done yet or not.
@ -357,7 +360,7 @@ class HttpStream(layer.Layer):
if (yield from self.check_killed(False)): if (yield from self.check_killed(False)):
return return
if isinstance(event, ResponseProtocolError): if isinstance(event, ResponseProtocolError) and self.client_state != self.state_errored:
yield SendHttp(event, self.context.client) yield SendHttp(event, self.context.client)
self.server_state = self.state_errored self.server_state = self.state_errored

View File

@ -317,7 +317,8 @@ class Http1Client(Http1Connection):
else: else:
pass # FIXME: protect against header size DoS pass # FIXME: protect against header size DoS
elif isinstance(event, events.ConnectionClosed): elif isinstance(event, events.ConnectionClosed):
yield commands.CloseConnection(self.conn) if self.conn.state & ConnectionState.CAN_WRITE:
yield commands.CloseConnection(self.conn)
if self.stream_id: if self.stream_id:
if self.buf: if self.buf:
yield ReceiveHttp(ResponseProtocolError(self.stream_id, yield ReceiveHttp(ResponseProtocolError(self.stream_id,

View File

@ -304,6 +304,38 @@ def test_rst_then_close(tctx):
assert flow().error.msg == "connection cancelled" assert flow().error.msg == "connection cancelled"
def test_cancel_then_server_disconnect(tctx):
"""
Test that we properly handle the case of the following event sequence:
- client cancels a stream
- we start an error hook
- server disconnects
- error hook completes.
"""
playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow)
server = Placeholder(Server)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b'GET / HTTP/1.1\r\nHost: example.com\r\n\r\n')
>> DataReceived(tctx.client, cff.build_rst_stream_frame(1, ErrorCodes.CANCEL).serialize())
<< CloseConnection(server)
<< http.HttpErrorHook(flow)
>> reply()
>> ConnectionClosed(server)
<< None
)
def test_stream_concurrency(tctx): def test_stream_concurrency(tctx):
"""Test that we can send an intercepted request with a lower stream id than one that has already been sent.""" """Test that we can send an intercepted request with a lower stream id than one that has already been sent."""
playbook, cff = start_h2_client(tctx) playbook, cff = start_h2_client(tctx)
@ -375,9 +407,6 @@ def test_max_concurrency(tctx):
flags=["END_STREAM"], flags=["END_STREAM"],
stream_id=3).serialize()) stream_id=3).serialize())
# Can't send it upstream yet, all streams in use! # Can't send it upstream yet, all streams in use!
)
assert (
playbook
>> DataReceived(server, sff.build_headers_frame(example_response_headers, >> DataReceived(server, sff.build_headers_frame(example_response_headers,
flags=["END_STREAM"], flags=["END_STREAM"],
stream_id=1).serialize()) stream_id=1).serialize())