HTTP/2 keepalive (#5144)

* Example addon for saving streamed data including a small bug fix to make it work.

* Revert "Example addon for saving streamed data including a small bug fix to make it work."

This reverts commit 02ab78def9a52eaca1a89d0757cd9475ce250eaa.

* Add https_ping_threshold option to enable keep-alive for HTTP/2 server connections by sending PING frames if the connection is idle longer than the threshold.

* Fixed test

* Fix test

* Adding pragma

* Moved timer logic to _http2.py

* Small code improvement

* Update mitmproxy/options.py

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Update mitmproxy/options.py

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Update mitmproxy/proxy/commands.py

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Update mitmproxy/proxy/commands.py

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Update mitmproxy/proxy/layers/http/_http2.py

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Sending Wakeup back to right client

* Update mitmproxy/proxy/server.py

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Update mitmproxy/proxy/server.py

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Update mitmproxy/proxy/server.py

Co-authored-by: Maximilian Hils <github@maximilianhils.com>

* Incorporated suggested changes
Fixed almost all tests

* make `Wakeup` a `CommandCompleted` event.

This allows us to use it with `reply()` in tests,
which makes sure that the correct instance is reused.
 # Please enter the commit message for your changes. Lines starting

* nits

`typing.Set` for Python 3.8 compatibility and a few minor stylistic changes.

* nits nits

Co-authored-by: Maximilian Hils <github@maximilianhils.com>
Co-authored-by: Maximilian Hils <git@maximilianhils.com>
This commit is contained in:
EndUser509 2022-04-22 13:59:55 +02:00 committed by GitHub
parent a863f529ab
commit 35703b0b6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 155 additions and 8 deletions

View File

@ -4,6 +4,8 @@
* Replayed flows retain their current position in the flow list. * Replayed flows retain their current position in the flow list.
([#5227](https://github.com/mitmproxy/mitmproxy/issues/5227), @mhils) ([#5227](https://github.com/mitmproxy/mitmproxy/issues/5227), @mhils)
* Periodically send HTTP/2 ping frames to keep connections alive.
([#5046](https://github.com/mitmproxy/mitmproxy/issues/5046), @EndUser509)
* Console Performance Improvements * Console Performance Improvements
([#3427](https://github.com/mitmproxy/mitmproxy/issues/3427), @BkPHcgQL3V) ([#3427](https://github.com/mitmproxy/mitmproxy/issues/3427), @BkPHcgQL3V)
* Warn users if server side event responses are received without streaming. * Warn users if server side event responses are received without streaming.
@ -24,7 +26,8 @@
([#4278](https://github.com/mitmproxy/mitmproxy/issues/4278), @kjy00302) ([#4278](https://github.com/mitmproxy/mitmproxy/issues/4278), @kjy00302)
* Fix mitmweb export copy failed in non-secure domain. * Fix mitmweb export copy failed in non-secure domain.
([#5264](https://github.com/mitmproxy/mitmproxy/issues/5264), @Pactortester) ([#5264](https://github.com/mitmproxy/mitmproxy/issues/5264), @Pactortester)
* Added example script for manipulating cookies. (@WillahScott) * Added example script for manipulating cookies.
([#5278](https://github.com/mitmproxy/mitmproxy/issues/5278), @WillahScott)
## 19 March 2022: mitmproxy 8.0.0 ## 19 March 2022: mitmproxy 8.0.0

View File

@ -106,6 +106,14 @@ class Options(optmanager.OptManager):
"Enable/disable HTTP/2 support. " "Enable/disable HTTP/2 support. "
"HTTP/2 support is enabled by default.", "HTTP/2 support is enabled by default.",
) )
self.add_option(
"http2_ping_keepalive", int, 58,
"""
Send a PING frame if an HTTP/2 connection is idle for more than
the specified number of seconds to prevent the remote site from closing it.
Set to 0 to disable this feature.
"""
)
self.add_option( self.add_option(
"websocket", bool, True, "websocket", bool, True,
"Enable/disable WebSocket support. " "Enable/disable WebSocket support. "

View File

@ -39,6 +39,16 @@ class Command:
return f"{type(self).__name__}({repr(x)})" return f"{type(self).__name__}({repr(x)})"
class RequestWakeup(Command):
"""
Request a `Wakeup` event after the specified number of seconds.

The connection handler starts a timer task that sleeps for `delay` seconds
and then emits a `Wakeup` event carrying this command instance.
"""
# Number of seconds to wait before the `Wakeup` event is delivered.
delay: float
def __init__(self, delay: float):
self.delay = delay
class ConnectionCommand(Command): class ConnectionCommand(Command):
""" """
Commands involving a specific connection Commands involving a specific connection

View File

@ -121,3 +121,11 @@ class MessageInjected(Event, typing.Generic[T]):
""" """
flow: flow.Flow flow: flow.Flow
message: T message: T
@dataclass
class Wakeup(CommandCompleted):
"""
Event sent to layers that requested a wakeup using `RequestWakeup`.

It is a `CommandCompleted` event, so it carries the originating command and
is routed back to the layer that issued that command.
"""
# The `RequestWakeup` command this event answers.
command: commands.RequestWakeup

View File

@ -708,6 +708,9 @@ class HttpLayer(layer.Layer):
yield from self.event_to_child(self.connections[self.context.client], event) yield from self.event_to_child(self.connections[self.context.client], event)
if self.mode is HTTPMode.upstream: if self.mode is HTTPMode.upstream:
self.context.server.via = server_spec.parse_with_mode(self.context.options.mode)[1] self.context.server.via = server_spec.parse_with_mode(self.context.options.mode)[1]
elif isinstance(event, events.Wakeup):
stream = self.command_sources.pop(event.command)
yield from self.event_to_child(stream, event)
elif isinstance(event, events.CommandCompleted): elif isinstance(event, events.CommandCompleted):
stream = self.command_sources.pop(event.command) stream = self.command_sources.pop(event.command)
yield from self.event_to_child(stream, event) yield from self.event_to_child(stream, event)
@ -765,7 +768,7 @@ class HttpLayer(layer.Layer):
# Streams may yield blocking commands, which ultimately generate CommandCompleted events. # Streams may yield blocking commands, which ultimately generate CommandCompleted events.
# Those need to be routed back to the correct stream, so we need to keep track of that. # Those need to be routed back to the correct stream, so we need to keep track of that.
if command.blocking: if command.blocking or isinstance(command, commands.RequestWakeup):
self.command_sources[command] = child self.command_sources[command] = child
if isinstance(command, ReceiveHttp): if isinstance(command, ReceiveHttp):

View File

@ -20,9 +20,9 @@ from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolE
ResponseEndOfMessage, ResponseHeaders, RequestTrailers, ResponseTrailers, ResponseProtocolError ResponseEndOfMessage, ResponseHeaders, RequestTrailers, ResponseTrailers, ResponseProtocolError
from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error
from ._http_h2 import BufferedH2Connection, H2ConnectionLogger from ._http_h2 import BufferedH2Connection, H2ConnectionLogger
from ...commands import CloseConnection, Log, SendData from ...commands import CloseConnection, Log, SendData, RequestWakeup
from ...context import Context from ...context import Context
from ...events import ConnectionClosed, DataReceived, Event, Start from ...events import ConnectionClosed, DataReceived, Event, Start, Wakeup
from ...layer import CommandGenerator from ...layer import CommandGenerator
from ...utils import expect from ...utils import expect
@ -252,7 +252,7 @@ class Http2Connection(HttpConnection):
self.streams.clear() self.streams.clear()
self._handle_event = self.done # type: ignore self._handle_event = self.done # type: ignore
@expect(DataReceived, HttpEvent, ConnectionClosed) @expect(DataReceived, HttpEvent, ConnectionClosed, Wakeup)
def done(self, _) -> CommandGenerator[None]: def done(self, _) -> CommandGenerator[None]:
yield from () yield from ()
@ -358,6 +358,8 @@ class Http2Client(Http2Connection):
"""Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS""" """Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS"""
provisional_max_concurrency: Optional[int] = 10 provisional_max_concurrency: Optional[int] = 10
"""A provisional currency limit before we get the server's first settings frame.""" """A provisional currency limit before we get the server's first settings frame."""
last_activity: float
"""Timestamp of when we've last seen network activity on this connection."""
def __init__(self, context: Context): def __init__(self, context: Context):
super().__init__(context, context.server) super().__init__(context, context.server)
@ -407,7 +409,30 @@ class Http2Client(Http2Connection):
yield from self._handle_event(event) yield from self._handle_event(event)
def _handle_event2(self, event: Event) -> CommandGenerator[None]: def _handle_event2(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, RequestHeaders): if isinstance(event, Wakeup):
send_ping_now = (
# add one second to avoid unnecessary roundtrip, we don't need to be super correct here.
time.time() - self.last_activity + 1 > self.context.options.http2_ping_keepalive
)
if send_ping_now:
# PING frames MUST contain 8 octets of opaque data in the payload.
# A sender can include any value it chooses and use those octets in any fashion.
self.last_activity = time.time()
self.h2_conn.ping(b"0" * 8)
data = self.h2_conn.data_to_send()
if data is not None:
yield Log(f"Send HTTP/2 keep-alive PING to {human.format_address(self.conn.peername)}")
yield SendData(self.conn, data)
time_until_next_ping = self.context.options.http2_ping_keepalive - (time.time() - self.last_activity)
yield RequestWakeup(time_until_next_ping)
return
self.last_activity = time.time()
if isinstance(event, Start):
if self.context.options.http2_ping_keepalive > 0:
yield RequestWakeup(self.context.options.http2_ping_keepalive)
yield from super()._handle_event(event)
elif isinstance(event, RequestHeaders):
pseudo_headers = [ pseudo_headers = [
(b':method', event.request.data.method), (b':method', event.request.data.method),
(b':scheme', event.request.data.scheme), (b':scheme', event.request.data.scheme),

View File

@ -76,11 +76,13 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
client: Client client: Client
max_conns: typing.DefaultDict[Address, asyncio.Semaphore] max_conns: typing.DefaultDict[Address, asyncio.Semaphore]
layer: layer.Layer layer: layer.Layer
wakeup_timer: typing.Set[asyncio.Task]
def __init__(self, context: Context) -> None: def __init__(self, context: Context) -> None:
self.client = context.client self.client = context.client
self.transports = {} self.transports = {}
self.max_conns = collections.defaultdict(lambda: asyncio.Semaphore(5)) self.max_conns = collections.defaultdict(lambda: asyncio.Semaphore(5))
self.wakeup_timer = set()
# Ask for the first layer right away. # Ask for the first layer right away.
# In a reverse proxy scenario, this is necessary as we would otherwise hang # In a reverse proxy scenario, this is necessary as we would otherwise hang
@ -120,6 +122,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
await asyncio.wait([handler]) await asyncio.wait([handler])
watch.cancel() watch.cancel()
while self.wakeup_timer:
timer = self.wakeup_timer.pop()
timer.cancel()
self.log("client disconnect") self.log("client disconnect")
self.client.timestamp_end = time.time() self.client.timestamp_end = time.time()
@ -199,6 +204,13 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
command.connection.timestamp_end = time.time() command.connection.timestamp_end = time.time()
await self.handle_hook(server_hooks.ServerDisconnectedHook(hook_data)) await self.handle_hook(server_hooks.ServerDisconnectedHook(hook_data))
async def wakeup(self, request: commands.RequestWakeup) -> None:
# Timer coroutine for a single RequestWakeup: wait `request.delay` seconds,
# then feed a Wakeup event for that request into the layer stack.
await asyncio.sleep(request.delay)
task = asyncio.current_task()
assert task is not None
# The timer has fired, so drop it from the set of pending timers
# that get cancelled when the client disconnects.
self.wakeup_timer.discard(task)
self.server_event(events.Wakeup(request))
async def handle_connection(self, connection: Connection) -> None: async def handle_connection(self, connection: Connection) -> None:
""" """
Handle a connection for its entire lifetime. Handle a connection for its entire lifetime.
@ -297,6 +309,14 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
client=self.client.peername, client=self.client.peername,
) )
self.transports[command.connection] = ConnectionIO(handler=handler) self.transports[command.connection] = ConnectionIO(handler=handler)
elif isinstance(command, commands.RequestWakeup):
task = asyncio_utils.create_task(
self.wakeup(command),
name=f"wakeup timer ({command.delay:.1f}s)",
client=self.client.peername
)
assert task is not None
self.wakeup_timer.add(task)
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports: elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
pass # The connection has already been closed. pass # The connection has already been closed.
elif isinstance(command, commands.SendData): elif isinstance(command, commands.SendData):

View File

@ -4,13 +4,14 @@ import h2.settings
import hpack import hpack
import hyperframe.frame import hyperframe.frame
import pytest import pytest
import time
from h2.errors import ErrorCodes from h2.errors import ErrorCodes
from mitmproxy.connection import ConnectionState, Server from mitmproxy.connection import ConnectionState, Server
from mitmproxy.flow import Error from mitmproxy.flow import Error
from mitmproxy.http import HTTPFlow, Headers, Request from mitmproxy.http import HTTPFlow, Headers, Request
from mitmproxy.net.http import status_codes from mitmproxy.net.http import status_codes
from mitmproxy.proxy.commands import CloseConnection, Log, OpenConnection, SendData from mitmproxy.proxy.commands import CloseConnection, Log, OpenConnection, SendData, RequestWakeup
from mitmproxy.proxy.context import Context from mitmproxy.proxy.context import Context
from mitmproxy.proxy.events import ConnectionClosed, DataReceived from mitmproxy.proxy.events import ConnectionClosed, DataReceived
from mitmproxy.proxy.layers import http from mitmproxy.proxy.layers import http
@ -64,8 +65,9 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
return frames return frames
def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]: def start_h2_client(tctx: Context, keepalive: int = 0) -> Tuple[Playbook, FrameFactory]:
tctx.client.alpn = b"h2" tctx.client.alpn = b"h2"
tctx.options.http2_ping_keepalive = keepalive
frame_factory = FrameFactory() frame_factory = FrameFactory()
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
@ -679,6 +681,7 @@ def test_kill_stream(tctx):
class TestClient: class TestClient:
def test_no_data_on_closed_stream(self, tctx): def test_no_data_on_closed_stream(self, tctx):
tctx.options.http2_ping_keepalive = 0
frame_factory = FrameFactory() frame_factory = FrameFactory()
req = Request.make("GET", "http://example.com/") req = Request.make("GET", "http://example.com/")
resp = { resp = {
@ -780,3 +783,64 @@ def test_request_smuggling_te(tctx):
<< CloseConnection(tctx.client) << CloseConnection(tctx.client)
) )
assert b"Connection-specific header field present" in err() assert b"Connection-specific header field present" in err()
def test_request_keepalive(tctx, monkeypatch):
# Checks that a Wakeup delivered after the keepalive interval has elapsed makes
# the HTTP/2 client send a PING frame to the server and request the next wakeup.
playbook, cff = start_h2_client(tctx, 58)
flow = Placeholder(HTTPFlow)
server = Placeholder(Server)
initial = Placeholder(bytes)
def advance_time(_):
# Pin time.time() to 60 seconds after now, i.e. past the 58 second keepalive.
t = time.time()
monkeypatch.setattr(time, "time", lambda: t + 60)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
# The keepalive timer is requested as soon as the server connection starts.
<< RequestWakeup(58)
<< SendData(server, initial)
# to=-2 presumably targets the RequestWakeup two commands back (verify against
# the Playbook helper); the side effect moves the clock past the threshold first.
>> reply(to=-2, side_effect=advance_time)
<< SendData(server, b'\x00\x00\x08\x06\x00\x00\x00\x00\x0000000000') # ping frame
<< RequestWakeup(58)
)
def test_keepalive_disconnect(tctx, monkeypatch):
# Checks that a Wakeup arriving after the server connection has been closed
# produces no further commands (no PING and no new RequestWakeup).
playbook, cff = start_h2_client(tctx, 58)
playbook.hooks = False
sff = FrameFactory()
server = Placeholder(Server)
# Keep a reference to the expected command so the reply below can target it explicitly.
wakeup_command = RequestWakeup(58)
http_response = (
sff.build_headers_frame(example_response_headers).serialize() +
sff.build_data_frame(b"", flags=["END_STREAM"]).serialize()
)
def advance_time(_):
# Pin time.time() to 60 seconds after now, i.e. past the 58 second keepalive.
t = time.time()
monkeypatch.setattr(time, "time", lambda: t + 60)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers, flags=["END_STREAM"]).serialize())
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< wakeup_command
<< SendData(server, Placeholder(bytes))
>> DataReceived(server, http_response)
<< SendData(tctx.client, Placeholder(bytes))
>> ConnectionClosed(server)
<< CloseConnection(server)
# The timer fires only after the connection is gone; nothing may be emitted in response.
>> reply(to=wakeup_command, side_effect=advance_time)
<< None
)

View File

@ -220,6 +220,7 @@ def h2_frames(draw):
def h2_layer(opts): def h2_layer(opts):
tctx = context.Context(connection.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), opts) tctx = context.Context(connection.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), opts)
tctx.options.http2_ping_keepalive = 0
tctx.client.alpn = b"h2" tctx.client.alpn = b"h2"
layer = http.HttpLayer(tctx, HTTPMode.regular) layer = http.HttpLayer(tctx, HTTPMode.regular)

View File

@ -24,6 +24,7 @@ def event_types(events):
def h2_client(tctx: Context) -> Tuple[h2.connection.H2Connection, Playbook]: def h2_client(tctx: Context) -> Tuple[h2.connection.H2Connection, Playbook]:
tctx.client.alpn = b"h2" tctx.client.alpn = b"h2"
tctx.options.http2_ping_keepalive = 0
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
conn = h2.connection.H2Connection() conn = h2.connection.H2Connection()
@ -89,6 +90,7 @@ def test_h1_to_h2(tctx):
"""Test HTTP/1 -> HTTP/2 request translation""" """Test HTTP/1 -> HTTP/2 request translation"""
server = Placeholder(Server) server = Placeholder(Server)
flow = Placeholder(HTTPFlow) flow = Placeholder(HTTPFlow)
tctx.options.http2_ping_keepalive = 0
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))

View File

@ -13,6 +13,7 @@ def tconn() -> connection.Server:
def test_dataclasses(tconn): def test_dataclasses(tconn):
assert repr(commands.RequestWakeup(58))
assert repr(commands.SendData(tconn, b"foo")) assert repr(commands.SendData(tconn, b"foo"))
assert repr(commands.OpenConnection(tconn)) assert repr(commands.OpenConnection(tconn))
assert repr(commands.CloseConnection(tconn)) assert repr(commands.CloseConnection(tconn))

View File

@ -22,6 +22,7 @@ export interface OptionsState {
content_view_lines_cutoff: number content_view_lines_cutoff: number
export_preserve_original_ip: boolean export_preserve_original_ip: boolean
http2: boolean http2: boolean
http2_ping_keepalive: number
ignore_hosts: string[] ignore_hosts: string[]
intercept: string | undefined intercept: string | undefined
intercept_active: boolean intercept_active: boolean
@ -110,6 +111,7 @@ export const defaultState: OptionsState = {
content_view_lines_cutoff: 512, content_view_lines_cutoff: 512,
export_preserve_original_ip: false, export_preserve_original_ip: false,
http2: true, http2: true,
http2_ping_keepalive: 58,
ignore_hosts: [], ignore_hosts: [],
intercept: undefined, intercept: undefined,
intercept_active: false, intercept_active: false,