[sans-io] HTTP/2: respect max_concurrency_limit

This commit is contained in:
Maximilian Hils 2020-11-23 03:55:47 +01:00
parent 41ed038bb9
commit 1112135920
2 changed files with 81 additions and 1 deletions

View File

@ -1,6 +1,7 @@
import collections
import time import time
from enum import Enum from enum import Enum
from typing import ClassVar, Dict, Iterable, List, Tuple, Type, Union from typing import ClassVar, DefaultDict, Dict, Iterable, List, Optional, Tuple, Type, Union
import h2.config import h2.config
import h2.connection import h2.connection
@ -287,6 +288,10 @@ class Http2Client(Http2Connection):
our_stream_id = Dict[int, int] our_stream_id = Dict[int, int]
their_stream_id = Dict[int, int] their_stream_id = Dict[int, int]
stream_queue = DefaultDict[int, List[Event]]
"""Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS"""
provisional_max_concurrency: Optional[int] = 10
"""A provisional currency limit before we get the server's first settings frame."""
def __init__(self, context: Context): def __init__(self, context: Context):
super().__init__(context, context.server) super().__init__(context, context.server)
@ -297,6 +302,7 @@ class Http2Client(Http2Connection):
self.h2_conn.local_settings.acknowledge() self.h2_conn.local_settings.acknowledge()
self.our_stream_id = {} self.our_stream_id = {}
self.their_stream_id = {} self.their_stream_id = {}
self.stream_queue = collections.defaultdict(list)
def _handle_event(self, event: Event) -> CommandGenerator[None]: def _handle_event(self, event: Event) -> CommandGenerator[None]:
# We can't reuse stream ids from the client because they may arrived reordered here # We can't reuse stream ids from the client because they may arrived reordered here
@ -305,6 +311,13 @@ class Http2Client(Http2Connection):
if isinstance(event, HttpEvent): if isinstance(event, HttpEvent):
ours = self.our_stream_id.get(event.stream_id, None) ours = self.our_stream_id.get(event.stream_id, None)
if ours is None: if ours is None:
no_free_streams = (
self.h2_conn.open_outbound_streams >=
(self.provisional_max_concurrency or self.h2_conn.remote_settings.max_concurrent_streams)
)
if no_free_streams:
self.stream_queue[event.stream_id].append(event)
return
ours = self.h2_conn.get_next_available_stream_id() ours = self.h2_conn.get_next_available_stream_id()
self.our_stream_id[event.stream_id] = ours self.our_stream_id[event.stream_id] = ours
self.their_stream_id[ours] = event.stream_id self.their_stream_id[ours] = event.stream_id
@ -315,6 +328,18 @@ class Http2Client(Http2Connection):
cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id] cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id]
yield cmd yield cmd
can_resume_queue = (
self.stream_queue and
self.h2_conn.open_outbound_streams < (
self.provisional_max_concurrency or self.h2_conn.remote_settings.max_concurrent_streams
)
)
if can_resume_queue:
# popitem would be LIFO, but we want FIFO.
events = self.stream_queue.pop(next(iter(self.stream_queue)))
for event in events:
yield from self._handle_event(event)
def _handle_event2(self, event: Event) -> CommandGenerator[None]: def _handle_event2(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, RequestHeaders): if isinstance(event, RequestHeaders):
pseudo_headers = [ pseudo_headers = [
@ -369,6 +394,11 @@ class Http2Client(Http2Connection):
elif isinstance(event, h2.events.RequestReceived): elif isinstance(event, h2.events.RequestReceived):
yield from self.protocol_error(f"HTTP/2 protocol error: received request from server") yield from self.protocol_error(f"HTTP/2 protocol error: received request from server")
return True return True
elif isinstance(event, h2.events.RemoteSettingsChanged):
# We have received at least one settings from now,
# which means we can rely on the max concurrency in remote_settings
self.provisional_max_concurrency = None
return (yield from super().handle_h2_event(event))
else: else:
return (yield from super().handle_h2_event(event)) return (yield from super().handle_h2_event(event))

View File

@ -1,5 +1,6 @@
from typing import List, Tuple from typing import List, Tuple
import h2.settings
import hpack import hpack
import hyperframe.frame import hyperframe.frame
import pytest import pytest
@ -350,3 +351,52 @@ def test_stream_concurrency(tctx):
hyperframe.frame.HeadersFrame, hyperframe.frame.HeadersFrame,
hyperframe.frame.DataFrame hyperframe.frame.DataFrame
] ]
def test_max_concurrency(tctx):
playbook, cff = start_h2_client(tctx)
server = Placeholder(Server)
req1_bytes = Placeholder(bytes)
settings_ack_bytes = Placeholder(bytes)
req2_bytes = Placeholder(bytes)
playbook.hooks = False
sff = FrameFactory()
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers, flags=["END_STREAM"],
stream_id=1).serialize())
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, req1_bytes)
>> DataReceived(server,
sff.build_settings_frame(
{h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 1}).serialize())
<< SendData(server, settings_ack_bytes)
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers,
flags=["END_STREAM"],
stream_id=3).serialize())
# Can't send it upstream yet, all streams in use!
>> DataReceived(server, sff.build_headers_frame(example_response_headers,
flags=["END_STREAM"],
stream_id=1).serialize())
# But now we can!
<< SendData(server, req2_bytes)
<< SendData(tctx.client, Placeholder(bytes))
>> DataReceived(server, sff.build_headers_frame(example_response_headers,
flags=["END_STREAM"],
stream_id=3).serialize())
<< SendData(tctx.client, Placeholder(bytes))
)
settings, req1 = decode_frames(req1_bytes())
settings_ack, = decode_frames(settings_ack_bytes())
req2, = decode_frames(req2_bytes())
assert type(settings) == hyperframe.frame.SettingsFrame
assert type(req1) == hyperframe.frame.HeadersFrame
assert type(settings_ack) == hyperframe.frame.SettingsFrame
assert type(req2) == hyperframe.frame.HeadersFrame
assert req1.stream_id == 1
assert req2.stream_id == 3