diff --git a/CHANGELOG.md b/CHANGELOG.md index 84f7bab65..f6edfef40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ * Replayed flows retain their current position in the flow list. ([#5227](https://github.com/mitmproxy/mitmproxy/issues/5227), @mhils) +* Periodically send HTTP/2 ping frames to keep connections alive. + ([#5046](https://github.com/mitmproxy/mitmproxy/issues/5046), @EndUser509) * Console Performance Improvements ([#3427](https://github.com/mitmproxy/mitmproxy/issues/3427), @BkPHcgQL3V) * Warn users if server side event responses are received without streaming. @@ -24,7 +26,8 @@ ([#4278](https://github.com/mitmproxy/mitmproxy/issues/4278), @kjy00302) * Fix mitmweb export copy failed in non-secure domain. ([#5264](https://github.com/mitmproxy/mitmproxy/issues/5264), @Pactortester) -* Added example script for manipulating cookies. (@WillahScott) +* Added example script for manipulating cookies. + ([#5278](https://github.com/mitmproxy/mitmproxy/issues/5278), @WillahScott) ## 19 March 2022: mitmproxy 8.0.0 diff --git a/mitmproxy/options.py b/mitmproxy/options.py index 14c90d220..e5716cacb 100644 --- a/mitmproxy/options.py +++ b/mitmproxy/options.py @@ -106,6 +106,14 @@ class Options(optmanager.OptManager): "Enable/disable HTTP/2 support. " "HTTP/2 support is enabled by default.", ) + self.add_option( + "http2_ping_keepalive", int, 58, + """ + Send a PING frame if an HTTP/2 connection is idle for more than + the specified number of seconds to prevent the remote site from closing it. + Set to 0 to disable this feature. + """ + ) self.add_option( "websocket", bool, True, "Enable/disable WebSocket support. " diff --git a/mitmproxy/proxy/commands.py b/mitmproxy/proxy/commands.py index 8f0799005..f95c50685 100644 --- a/mitmproxy/proxy/commands.py +++ b/mitmproxy/proxy/commands.py @@ -39,6 +39,16 @@ class Command: return f"{type(self).__name__}({repr(x)})" +class RequestWakeup(Command): + """ + Request a `Wakeup` event after the specified amount of seconds. + """ + delay: float + + def __init__(self, delay: float): + self.delay = delay + + class ConnectionCommand(Command): """ Commands involving a specific connection diff --git a/mitmproxy/proxy/events.py b/mitmproxy/proxy/events.py index 7cbf26661..5506d34f8 100644 --- a/mitmproxy/proxy/events.py +++ b/mitmproxy/proxy/events.py @@ -121,3 +121,11 @@ class MessageInjected(Event, typing.Generic[T]): """ flow: flow.Flow message: T + + +@dataclass +class Wakeup(CommandCompleted): + """ + Event sent to layers that requested a wakeup using RequestWakeup. + """ + command: commands.RequestWakeup diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index db1a5d571..47703cf71 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -708,6 +708,9 @@ class HttpLayer(layer.Layer): yield from self.event_to_child(self.connections[self.context.client], event) if self.mode is HTTPMode.upstream: self.context.server.via = server_spec.parse_with_mode(self.context.options.mode)[1] + elif isinstance(event, events.Wakeup): + stream = self.command_sources.pop(event.command) + yield from self.event_to_child(stream, event) elif isinstance(event, events.CommandCompleted): stream = self.command_sources.pop(event.command) yield from self.event_to_child(stream, event) @@ -765,7 +768,7 @@ class HttpLayer(layer.Layer): # Streams may yield blocking commands, which ultimately generate CommandCompleted events. # Those need to be routed back to the correct stream, so we need to keep track of that. - if command.blocking: + if command.blocking or isinstance(command, commands.RequestWakeup): self.command_sources[command] = child if isinstance(command, ReceiveHttp): diff --git a/mitmproxy/proxy/layers/http/_http2.py b/mitmproxy/proxy/layers/http/_http2.py index d72b18436..362c6b0a6 100644 --- a/mitmproxy/proxy/layers/http/_http2.py +++ b/mitmproxy/proxy/layers/http/_http2.py @@ -20,9 +20,9 @@ from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolE ResponseEndOfMessage, ResponseHeaders, RequestTrailers, ResponseTrailers, ResponseProtocolError from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error from ._http_h2 import BufferedH2Connection, H2ConnectionLogger -from ...commands import CloseConnection, Log, SendData +from ...commands import CloseConnection, Log, SendData, RequestWakeup from ...context import Context -from ...events import ConnectionClosed, DataReceived, Event, Start +from ...events import ConnectionClosed, DataReceived, Event, Start, Wakeup from ...layer import CommandGenerator from ...utils import expect @@ -252,7 +252,7 @@ class Http2Connection(HttpConnection): self.streams.clear() self._handle_event = self.done # type: ignore - @expect(DataReceived, HttpEvent, ConnectionClosed) + @expect(DataReceived, HttpEvent, ConnectionClosed, Wakeup) def done(self, _) -> CommandGenerator[None]: yield from () @@ -358,6 +358,8 @@ class Http2Client(Http2Connection): """Queue of streams that we haven't sent yet because we have reached MAX_CONCURRENT_STREAMS""" provisional_max_concurrency: Optional[int] = 10 """A provisional currency limit before we get the server's first settings frame.""" + last_activity: float + """Timestamp of when we've last seen network activity on this connection.""" def __init__(self, context: Context): super().__init__(context, context.server) @@ -407,7 +409,30 @@ class Http2Client(Http2Connection): yield from self._handle_event(event) def _handle_event2(self, event: Event) -> CommandGenerator[None]: - if isinstance(event, RequestHeaders): + if isinstance(event, Wakeup): + send_ping_now = ( + # add one second to avoid unnecessary roundtrip, we don't need to be super correct here. + time.time() - self.last_activity + 1 > self.context.options.http2_ping_keepalive + ) + if send_ping_now: + # PING frames MUST contain 8 octets of opaque data in the payload. + # A sender can include any value it chooses and use those octets in any fashion. + self.last_activity = time.time() + self.h2_conn.ping(b"0" * 8) + data = self.h2_conn.data_to_send() + if data is not None: + yield Log(f"Send HTTP/2 keep-alive PING to {human.format_address(self.conn.peername)}") + yield SendData(self.conn, data) + time_until_next_ping = self.context.options.http2_ping_keepalive - (time.time() - self.last_activity) + yield RequestWakeup(time_until_next_ping) + return + + self.last_activity = time.time() + if isinstance(event, Start): + if self.context.options.http2_ping_keepalive > 0: + yield RequestWakeup(self.context.options.http2_ping_keepalive) + yield from super()._handle_event(event) + elif isinstance(event, RequestHeaders): pseudo_headers = [ (b':method', event.request.data.method), (b':scheme', event.request.data.scheme), diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 957091a48..230e67ca7 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -76,11 +76,13 @@ class ConnectionHandler(metaclass=abc.ABCMeta): client: Client max_conns: typing.DefaultDict[Address, asyncio.Semaphore] layer: layer.Layer + wakeup_timer: typing.Set[asyncio.Task] def __init__(self, context: Context) -> None: self.client = context.client self.transports = {} self.max_conns = collections.defaultdict(lambda: asyncio.Semaphore(5)) + self.wakeup_timer = set() # Ask for the first layer right away. # In a reverse proxy scenario, this is necessary as we would otherwise hang @@ -120,6 +122,9 @@ class ConnectionHandler(metaclass=abc.ABCMeta): await asyncio.wait([handler]) watch.cancel() + while self.wakeup_timer: + timer = self.wakeup_timer.pop() + timer.cancel() self.log("client disconnect") self.client.timestamp_end = time.time() @@ -199,6 +204,13 @@ class ConnectionHandler(metaclass=abc.ABCMeta): command.connection.timestamp_end = time.time() await self.handle_hook(server_hooks.ServerDisconnectedHook(hook_data)) + async def wakeup(self, request: commands.RequestWakeup) -> None: + await asyncio.sleep(request.delay) + task = asyncio.current_task() + assert task is not None + self.wakeup_timer.discard(task) + self.server_event(events.Wakeup(request)) + async def handle_connection(self, connection: Connection) -> None: """ Handle a connection for its entire lifetime. @@ -297,6 +309,14 @@ class ConnectionHandler(metaclass=abc.ABCMeta): client=self.client.peername, ) self.transports[command.connection] = ConnectionIO(handler=handler) + elif isinstance(command, commands.RequestWakeup): + task = asyncio_utils.create_task( + self.wakeup(command), + name=f"wakeup timer ({command.delay:.1f}s)", + client=self.client.peername + ) + assert task is not None + self.wakeup_timer.add(task) elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports: pass # The connection has already been closed. elif isinstance(command, commands.SendData): diff --git a/test/mitmproxy/proxy/layers/http/test_http2.py b/test/mitmproxy/proxy/layers/http/test_http2.py index 2044b0bb1..e1464473a 100644 --- a/test/mitmproxy/proxy/layers/http/test_http2.py +++ b/test/mitmproxy/proxy/layers/http/test_http2.py @@ -4,13 +4,14 @@ import h2.settings import hpack import hyperframe.frame import pytest +import time from h2.errors import ErrorCodes from mitmproxy.connection import ConnectionState, Server from mitmproxy.flow import Error from mitmproxy.http import HTTPFlow, Headers, Request from mitmproxy.net.http import status_codes -from mitmproxy.proxy.commands import CloseConnection, Log, OpenConnection, SendData +from mitmproxy.proxy.commands import CloseConnection, Log, OpenConnection, SendData, RequestWakeup from mitmproxy.proxy.context import Context from mitmproxy.proxy.events import ConnectionClosed, DataReceived from mitmproxy.proxy.layers import http @@ -64,8 +65,9 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]: return frames -def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]: +def start_h2_client(tctx: Context, keepalive: int = 0) -> Tuple[Playbook, FrameFactory]: tctx.client.alpn = b"h2" + tctx.options.http2_ping_keepalive = keepalive frame_factory = FrameFactory() playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) @@ -679,6 +681,7 @@ def test_kill_stream(tctx): class TestClient: def test_no_data_on_closed_stream(self, tctx): + tctx.options.http2_ping_keepalive = 0 frame_factory = FrameFactory() req = Request.make("GET", "http://example.com/") resp = { @@ -780,3 +783,64 @@ def test_request_smuggling_te(tctx): << CloseConnection(tctx.client) ) assert b"Connection-specific header field present" in err() + + +def test_request_keepalive(tctx, monkeypatch): + playbook, cff = start_h2_client(tctx, 58) + flow = Placeholder(HTTPFlow) + server = Placeholder(Server) + initial = Placeholder(bytes) + + def advance_time(_): + t = time.time() + monkeypatch.setattr(time, "time", lambda: t + 60) + + assert ( + playbook + >> DataReceived(tctx.client, + cff.build_headers_frame(example_request_headers, flags=["END_STREAM"]).serialize()) + << http.HttpRequestHeadersHook(flow) + >> reply() + << http.HttpRequestHook(flow) + >> reply() + << OpenConnection(server) + >> reply(None, side_effect=make_h2) + << RequestWakeup(58) + << SendData(server, initial) + >> reply(to=-2, side_effect=advance_time) + << SendData(server, b'\x00\x00\x08\x06\x00\x00\x00\x00\x0000000000') # ping frame + << RequestWakeup(58) + ) + + +def test_keepalive_disconnect(tctx, monkeypatch): + playbook, cff = start_h2_client(tctx, 58) + playbook.hooks = False + sff = FrameFactory() + server = Placeholder(Server) + wakeup_command = RequestWakeup(58) + + http_response = ( + sff.build_headers_frame(example_response_headers).serialize() + + sff.build_data_frame(b"", flags=["END_STREAM"]).serialize() + ) + + def advance_time(_): + t = time.time() + monkeypatch.setattr(time, "time", lambda: t + 60) + + assert ( + playbook + >> DataReceived(tctx.client, + cff.build_headers_frame(example_request_headers, flags=["END_STREAM"]).serialize()) + << OpenConnection(server) + >> reply(None, side_effect=make_h2) + << wakeup_command + << SendData(server, Placeholder(bytes)) + >> DataReceived(server, http_response) + << SendData(tctx.client, Placeholder(bytes)) + >> ConnectionClosed(server) + << CloseConnection(server) + >> reply(to=wakeup_command, side_effect=advance_time) + << None + ) diff --git a/test/mitmproxy/proxy/layers/http/test_http_fuzz.py b/test/mitmproxy/proxy/layers/http/test_http_fuzz.py index 9ca689407..c0cc0142b 100644 --- a/test/mitmproxy/proxy/layers/http/test_http_fuzz.py +++ b/test/mitmproxy/proxy/layers/http/test_http_fuzz.py @@ -220,6 +220,7 @@ def h2_frames(draw): def h2_layer(opts): tctx = context.Context(connection.Client(("client", 1234), ("127.0.0.1", 8080), 1605699329), opts) + tctx.options.http2_ping_keepalive = 0 tctx.client.alpn = b"h2" layer = http.HttpLayer(tctx, HTTPMode.regular) diff --git a/test/mitmproxy/proxy/layers/http/test_http_version_interop.py b/test/mitmproxy/proxy/layers/http/test_http_version_interop.py index 923516854..ac214f21c 100644 --- a/test/mitmproxy/proxy/layers/http/test_http_version_interop.py +++ b/test/mitmproxy/proxy/layers/http/test_http_version_interop.py @@ -24,6 +24,7 @@ def event_types(events): def h2_client(tctx: Context) -> Tuple[h2.connection.H2Connection, Playbook]: tctx.client.alpn = b"h2" + tctx.options.http2_ping_keepalive = 0 playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) conn = h2.connection.H2Connection() @@ -89,6 +90,7 @@ def test_h1_to_h2(tctx): """Test HTTP/1 -> HTTP/2 request translation""" server = Placeholder(Server) flow = Placeholder(HTTPFlow) + tctx.options.http2_ping_keepalive = 0 playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) diff --git a/test/mitmproxy/proxy/test_commands.py b/test/mitmproxy/proxy/test_commands.py index ead32ca3a..2ca7c23f0 100644 --- a/test/mitmproxy/proxy/test_commands.py +++ b/test/mitmproxy/proxy/test_commands.py @@ -13,6 +13,7 @@ def tconn() -> connection.Server: def test_dataclasses(tconn): + assert repr(commands.RequestWakeup(58)) assert repr(commands.SendData(tconn, b"foo")) assert repr(commands.OpenConnection(tconn)) assert repr(commands.CloseConnection(tconn)) diff --git a/web/src/js/ducks/_options_gen.ts b/web/src/js/ducks/_options_gen.ts index 87c3f7bd9..26ac9f143 100644 --- a/web/src/js/ducks/_options_gen.ts +++ b/web/src/js/ducks/_options_gen.ts @@ -22,6 +22,7 @@ export interface OptionsState { content_view_lines_cutoff: number export_preserve_original_ip: boolean http2: boolean + http2_ping_keepalive: number ignore_hosts: string[] intercept: string | undefined intercept_active: boolean @@ -110,6 +111,7 @@ export const defaultState: OptionsState = { content_view_lines_cutoff: 512, export_preserve_original_ip: false, http2: true, + http2_ping_keepalive: 58, ignore_hosts: [], intercept: undefined, intercept_active: false,