mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] HTTP/2: respect max_concurrency_limit
This commit is contained in:
parent
41ed038bb9
commit
1112135920
@ -1,6 +1,7 @@
|
|||||||
|
import collections
|
||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import ClassVar, Dict, Iterable, List, Tuple, Type, Union
|
from typing import ClassVar, DefaultDict, Dict, Iterable, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import h2.config
|
import h2.config
|
||||||
import h2.connection
|
import h2.connection
|
||||||
@ -287,6 +288,10 @@ class Http2Client(Http2Connection):
|
|||||||
|
|
||||||
our_stream_id = Dict[int, int]
|
our_stream_id = Dict[int, int]
|
||||||
their_stream_id = Dict[int, int]
|
their_stream_id = Dict[int, int]
|
||||||
|
stream_queue = DefaultDict[int, List[Event]]
|
||||||
|
"""Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS"""
|
||||||
|
provisional_max_concurrency: Optional[int] = 10
|
||||||
|
"""A provisional currency limit before we get the server's first settings frame."""
|
||||||
|
|
||||||
def __init__(self, context: Context):
|
def __init__(self, context: Context):
|
||||||
super().__init__(context, context.server)
|
super().__init__(context, context.server)
|
||||||
@ -297,6 +302,7 @@ class Http2Client(Http2Connection):
|
|||||||
self.h2_conn.local_settings.acknowledge()
|
self.h2_conn.local_settings.acknowledge()
|
||||||
self.our_stream_id = {}
|
self.our_stream_id = {}
|
||||||
self.their_stream_id = {}
|
self.their_stream_id = {}
|
||||||
|
self.stream_queue = collections.defaultdict(list)
|
||||||
|
|
||||||
def _handle_event(self, event: Event) -> CommandGenerator[None]:
|
def _handle_event(self, event: Event) -> CommandGenerator[None]:
|
||||||
# We can't reuse stream ids from the client because they may arrived reordered here
|
# We can't reuse stream ids from the client because they may arrived reordered here
|
||||||
@ -305,6 +311,13 @@ class Http2Client(Http2Connection):
|
|||||||
if isinstance(event, HttpEvent):
|
if isinstance(event, HttpEvent):
|
||||||
ours = self.our_stream_id.get(event.stream_id, None)
|
ours = self.our_stream_id.get(event.stream_id, None)
|
||||||
if ours is None:
|
if ours is None:
|
||||||
|
no_free_streams = (
|
||||||
|
self.h2_conn.open_outbound_streams >=
|
||||||
|
(self.provisional_max_concurrency or self.h2_conn.remote_settings.max_concurrent_streams)
|
||||||
|
)
|
||||||
|
if no_free_streams:
|
||||||
|
self.stream_queue[event.stream_id].append(event)
|
||||||
|
return
|
||||||
ours = self.h2_conn.get_next_available_stream_id()
|
ours = self.h2_conn.get_next_available_stream_id()
|
||||||
self.our_stream_id[event.stream_id] = ours
|
self.our_stream_id[event.stream_id] = ours
|
||||||
self.their_stream_id[ours] = event.stream_id
|
self.their_stream_id[ours] = event.stream_id
|
||||||
@ -315,6 +328,18 @@ class Http2Client(Http2Connection):
|
|||||||
cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id]
|
cmd.event.stream_id = self.their_stream_id[cmd.event.stream_id]
|
||||||
yield cmd
|
yield cmd
|
||||||
|
|
||||||
|
can_resume_queue = (
|
||||||
|
self.stream_queue and
|
||||||
|
self.h2_conn.open_outbound_streams < (
|
||||||
|
self.provisional_max_concurrency or self.h2_conn.remote_settings.max_concurrent_streams
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if can_resume_queue:
|
||||||
|
# popitem would be LIFO, but we want FIFO.
|
||||||
|
events = self.stream_queue.pop(next(iter(self.stream_queue)))
|
||||||
|
for event in events:
|
||||||
|
yield from self._handle_event(event)
|
||||||
|
|
||||||
def _handle_event2(self, event: Event) -> CommandGenerator[None]:
|
def _handle_event2(self, event: Event) -> CommandGenerator[None]:
|
||||||
if isinstance(event, RequestHeaders):
|
if isinstance(event, RequestHeaders):
|
||||||
pseudo_headers = [
|
pseudo_headers = [
|
||||||
@ -369,6 +394,11 @@ class Http2Client(Http2Connection):
|
|||||||
elif isinstance(event, h2.events.RequestReceived):
|
elif isinstance(event, h2.events.RequestReceived):
|
||||||
yield from self.protocol_error(f"HTTP/2 protocol error: received request from server")
|
yield from self.protocol_error(f"HTTP/2 protocol error: received request from server")
|
||||||
return True
|
return True
|
||||||
|
elif isinstance(event, h2.events.RemoteSettingsChanged):
|
||||||
|
# We have received at least one settings from now,
|
||||||
|
# which means we can rely on the max concurrency in remote_settings
|
||||||
|
self.provisional_max_concurrency = None
|
||||||
|
return (yield from super().handle_h2_event(event))
|
||||||
else:
|
else:
|
||||||
return (yield from super().handle_h2_event(event))
|
return (yield from super().handle_h2_event(event))
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import h2.settings
|
||||||
import hpack
|
import hpack
|
||||||
import hyperframe.frame
|
import hyperframe.frame
|
||||||
import pytest
|
import pytest
|
||||||
@ -350,3 +351,52 @@ def test_stream_concurrency(tctx):
|
|||||||
hyperframe.frame.HeadersFrame,
|
hyperframe.frame.HeadersFrame,
|
||||||
hyperframe.frame.DataFrame
|
hyperframe.frame.DataFrame
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_max_concurrency(tctx):
|
||||||
|
playbook, cff = start_h2_client(tctx)
|
||||||
|
server = Placeholder(Server)
|
||||||
|
req1_bytes = Placeholder(bytes)
|
||||||
|
settings_ack_bytes = Placeholder(bytes)
|
||||||
|
req2_bytes = Placeholder(bytes)
|
||||||
|
playbook.hooks = False
|
||||||
|
sff = FrameFactory()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
playbook
|
||||||
|
>> DataReceived(tctx.client,
|
||||||
|
cff.build_headers_frame(example_request_headers, flags=["END_STREAM"],
|
||||||
|
stream_id=1).serialize())
|
||||||
|
<< OpenConnection(server)
|
||||||
|
>> reply(None, side_effect=make_h2)
|
||||||
|
<< SendData(server, req1_bytes)
|
||||||
|
>> DataReceived(server,
|
||||||
|
sff.build_settings_frame(
|
||||||
|
{h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 1}).serialize())
|
||||||
|
<< SendData(server, settings_ack_bytes)
|
||||||
|
>> DataReceived(tctx.client,
|
||||||
|
cff.build_headers_frame(example_request_headers,
|
||||||
|
flags=["END_STREAM"],
|
||||||
|
stream_id=3).serialize())
|
||||||
|
# Can't send it upstream yet, all streams in use!
|
||||||
|
>> DataReceived(server, sff.build_headers_frame(example_response_headers,
|
||||||
|
flags=["END_STREAM"],
|
||||||
|
stream_id=1).serialize())
|
||||||
|
# But now we can!
|
||||||
|
<< SendData(server, req2_bytes)
|
||||||
|
<< SendData(tctx.client, Placeholder(bytes))
|
||||||
|
>> DataReceived(server, sff.build_headers_frame(example_response_headers,
|
||||||
|
flags=["END_STREAM"],
|
||||||
|
stream_id=3).serialize())
|
||||||
|
<< SendData(tctx.client, Placeholder(bytes))
|
||||||
|
)
|
||||||
|
settings, req1 = decode_frames(req1_bytes())
|
||||||
|
settings_ack, = decode_frames(settings_ack_bytes())
|
||||||
|
req2, = decode_frames(req2_bytes())
|
||||||
|
|
||||||
|
assert type(settings) == hyperframe.frame.SettingsFrame
|
||||||
|
assert type(req1) == hyperframe.frame.HeadersFrame
|
||||||
|
assert type(settings_ack) == hyperframe.frame.SettingsFrame
|
||||||
|
assert type(req2) == hyperframe.frame.HeadersFrame
|
||||||
|
assert req1.stream_id == 1
|
||||||
|
assert req2.stream_id == 3
|
||||||
|
Loading…
Reference in New Issue
Block a user