diff --git a/mitmproxy/proxy2/layers/http/__init__.py b/mitmproxy/proxy2/layers/http/__init__.py index 87d44379e..9810616ff 100644 --- a/mitmproxy/proxy2/layers/http/__init__.py +++ b/mitmproxy/proxy2/layers/http/__init__.py @@ -570,9 +570,9 @@ class HttpClient(layer.Layer): err = yield commands.OpenConnection(self.context.server) if not err: if self.context.server.alpn == b"h2": - raise NotImplementedError + child_layer = Http2Client(self.context) else: child_layer = Http1Client(self.context) - self._handle_event = child_layer.handle_event + self._handle_event = child_layer.handle_event yield from self._handle_event(event) yield RegisterHttpConnection(self.context.server, err) diff --git a/mitmproxy/proxy2/layers/http/_http2.py b/mitmproxy/proxy2/layers/http/_http2.py index 3eca345af..80322c21e 100644 --- a/mitmproxy/proxy2/layers/http/_http2.py +++ b/mitmproxy/proxy2/layers/http/_http2.py @@ -1,5 +1,4 @@ -import time -from typing import ClassVar +from typing import ClassVar, Dict, Iterable, List, Optional, Tuple, Type, Union import h2.connection import h2.config @@ -8,10 +7,10 @@ import h2.exceptions import h2.settings import h2.errors import h2.utilities +from hyperframe.frame import SettingsFrame from mitmproxy import http from mitmproxy.net import http as net_http -from mitmproxy.net.http import http2 from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError from ._base import HttpConnection, HttpEvent, ReceiveHttp @@ -24,27 +23,26 @@ from ...layer import CommandGenerator class Http2Connection(HttpConnection): h2_conf: ClassVar[h2.config.H2Configuration] - h2_conn: BufferedH2Connection - - def __init__(self, context: Context, conn: Connection): - super().__init__(context, conn) - self.h2_conn = BufferedH2Connection(self.h2_conf) - - -class Http2Server(Http2Connection): - # noinspection PyTypeChecker - h2_conf = h2.config.H2Configuration( - client_side=False, + h2_conf_defaults = dict( header_encoding=False, validate_outbound_headers=False, validate_inbound_headers=False, normalize_inbound_headers=False, normalize_outbound_headers=False, - logger=H2ConnectionLogger("server") # type: ignore + logger=H2ConnectionLogger("server") ) + h2_conn: BufferedH2Connection - def __init__(self, context: Context): - super().__init__(context, context.client) + ReceiveProtocolError: Type[Union[RequestProtocolError, ResponseProtocolError]] + SendProtocolError: Type[Union[RequestProtocolError, ResponseProtocolError]] + ReceiveData: Type[Union[RequestData, ResponseData]] + SendData: Type[Union[RequestData, ResponseData]] + ReceiveEndOfMessage: Type[Union[RequestEndOfMessage, ResponseEndOfMessage]] + SendEndOfMessage: Type[Union[RequestEndOfMessage, ResponseEndOfMessage]] + + def __init__(self, context: Context, conn: Connection): + super().__init__(context, conn) + self.h2_conn = BufferedH2Connection(self.h2_conf) def _handle_event(self, event: Event) -> CommandGenerator[None]: if isinstance(event, Start): @@ -52,30 +50,11 @@ class Http2Server(Http2Connection): yield SendData(self.conn, self.h2_conn.data_to_send()) elif isinstance(event, HttpEvent): - if isinstance(event, ResponseHeaders): - headers = ( - (b":status", b"%d" % event.response.status_code), - *event.response.headers.fields - ) - if event.response.data.http_version != b"HTTP/2": - # HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length), - # which isn't valid HTTP/2. As such we normalize. - headers = h2.utilities.normalize_outbound_headers( - headers, - h2.utilities.HeaderValidationFlags(False, False, True, False) - ) - # make sure that this is not just an iterator but an iterable, - # otherwise hyper-h2 will silently drop headers. - headers = list(headers) - self.h2_conn.send_headers( - event.stream_id, - headers, - ) - elif isinstance(event, ResponseData): + if isinstance(event, self.SendData): self.h2_conn.send_data(event.stream_id, event.data) - elif isinstance(event, ResponseEndOfMessage): + elif isinstance(event, self.SendEndOfMessage): self.h2_conn.send_data(event.stream_id, b"", end_stream=True) - elif isinstance(event, ResponseProtocolError): + elif isinstance(event, self.SendProtocolError): self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR) else: raise NotImplementedError(f"Unknown HTTP event: {event}") @@ -88,62 +67,225 @@ class Http2Server(Http2Connection): events = [e] for h2_event in events: - if isinstance(h2_event, h2.events.RequestReceived): - headers = net_http.Headers([(k, v) for k, v in h2_event.headers]) - first_line_format, method, scheme, host, port, path = http2.parse_headers(headers) - headers["Host"] = headers.pop(":authority") # FIXME: temporary workaround - request = http.HTTPRequest( - first_line_format, - method, - scheme, - host, - port, - path, - b"HTTP/1.1", # FIXME: Figure out how to smooth h2 <-> h1. - headers, - None, - timestamp_start=time.time(), - ) - yield ReceiveHttp(RequestHeaders(h2_event.stream_id, request)) - elif isinstance(h2_event, h2.events.DataReceived): - yield ReceiveHttp(RequestData(h2_event.stream_id, h2_event.data)) - self.h2_conn.acknowledge_received_data(len(h2_event.data), h2_event.stream_id) - elif isinstance(h2_event, h2.events.StreamEnded): - yield ReceiveHttp(RequestEndOfMessage(h2_event.stream_id)) - elif isinstance(h2_event, h2.exceptions.ProtocolError): - yield CloseConnection(self.conn) - yield from self._notify_close(f"HTTP/2 protocol error: {h2_event}") + if (yield from self.handle_h2_event(h2_event)): return - elif isinstance(h2_event, h2.events.ConnectionTerminated): - yield CloseConnection(self.conn) - yield from self._notify_close(f"HTTP/2 connection closed: {h2_event!r}") - return - elif isinstance(h2_event, h2.events.StreamReset): - yield ReceiveHttp(RequestProtocolError(h2_event.stream_id, "EOF")) - elif isinstance(h2_event, h2.events.RemoteSettingsChanged): - pass - elif isinstance(h2_event, h2.events.SettingsAcknowledged): - pass - else: - raise NotImplementedError(f"Unknown event: {h2_event!r}") data_to_send = self.h2_conn.data_to_send() if data_to_send: yield SendData(self.conn, data_to_send) + elif isinstance(event, ConnectionClosed): - yield CloseConnection(self.conn) - yield from self._notify_close("peer closed connection") + yield from self._unexpected_close("peer closed connection") else: raise NotImplementedError(f"Unexpected event: {event!r}") - def _notify_close(self, err: str) -> CommandGenerator[None]: + def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]: + """returns true if further processing should be stopped.""" + if isinstance(event, h2.events.DataReceived): + # noinspection PyArgumentList + yield ReceiveHttp(self.ReceiveData(event.stream_id, event.data)) + self.h2_conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) + elif isinstance(event, h2.events.StreamEnded): + # noinspection PyArgumentList + yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id)) + elif isinstance(event, h2.exceptions.ProtocolError): + yield from self._unexpected_close(f"HTTP/2 protocol error: {event}") + return True + elif isinstance(event, h2.events.ConnectionTerminated): + yield from self._unexpected_close(f"HTTP/2 connection closed: {event!r}") + return True + elif isinstance(event, h2.events.StreamReset): + # noinspection PyArgumentList + yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, "Stream reset")) + elif isinstance(event, h2.events.RemoteSettingsChanged): + pass + elif isinstance(event, h2.events.SettingsAcknowledged): + pass + else: + raise NotImplementedError(f"Unknown event: {event!r}") + + def _unexpected_close(self, err: str) -> CommandGenerator[None]: + yield CloseConnection(self.conn) for stream_id, stream in self.h2_conn.streams.items(): if stream.open: - yield ReceiveHttp(RequestProtocolError(stream_id, err)) + # noinspection PyArgumentList + yield ReceiveHttp(self.ReceiveProtocolError(stream_id, err)) -class Http2Client: - pass # TODO +def normalize_h1_headers(headers: List[Tuple[bytes, bytes]], is_client: bool) -> List[Tuple[bytes, bytes]]: + # HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length), + # which isn't valid HTTP/2. As such we normalize. + headers = h2.utilities.normalize_outbound_headers( + headers, + h2.utilities.HeaderValidationFlags(is_client, False, not is_client, False) + ) + # make sure that this is not just an iterator but an iterable, + # otherwise hyper-h2 will silently drop headers. + headers = list(headers) + return headers + + +class Http2Server(Http2Connection): + h2_conf = h2.config.H2Configuration( + client_side=False, + **Http2Connection.h2_conf_defaults + ) + + ReceiveProtocolError = RequestProtocolError + SendProtocolError = ResponseProtocolError + ReceiveData = RequestData + SendData = ResponseData + ReceiveEndOfMessage = RequestEndOfMessage + SendEndOfMessage = ResponseEndOfMessage + + def __init__(self, context: Context): + super().__init__(context, context.client) + + def _handle_event(self, event: Event) -> CommandGenerator[None]: + if isinstance(event, ResponseHeaders): + headers = [ + (b":status", b"%d" % event.response.status_code), + *event.response.headers.fields + ] + if event.response.http_version != b"HTTP/2": + headers = normalize_h1_headers(headers, False) + + self.h2_conn.send_headers( + event.stream_id, + headers, + ) + yield SendData(self.conn, self.h2_conn.data_to_send()) + else: + yield from super()._handle_event(event) + + def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]: + if isinstance(event, h2.events.RequestReceived): + method, scheme, host, port, path, headers = parse_h2_request_headers(event.headers) + request = http.HTTPRequest( + "relative", + method, + scheme, + host, + port, + path, + b"HTTP/2", + headers, + None, + ) + yield ReceiveHttp(RequestHeaders(event.stream_id, request)) + else: + return (yield from super().handle_h2_event(event)) + + +class Http2Client(Http2Connection): + h2_conf = h2.config.H2Configuration( + client_side=True, + **Http2Connection.h2_conf_defaults + ) + + ReceiveProtocolError = ResponseProtocolError + SendProtocolError = RequestProtocolError + ReceiveData = ResponseData + SendData = RequestData + ReceiveEndOfMessage = ResponseEndOfMessage + SendEndOfMessage = RequestEndOfMessage + + def __init__(self, context: Context): + super().__init__(context, context.server) + # Disable HTTP/2 push for now to keep things simple. + self.h2_conn.update_settings({SettingsFrame.ENABLE_PUSH: 0}) + + def _handle_event(self, event: Event) -> CommandGenerator[None]: + if isinstance(event, RequestHeaders): + headers = [ + (b':method', event.request.method), + (b':scheme', event.request.scheme), + (b':path', event.request.path), + *event.request.headers.fields + ] + if event.request.http_version == b"HTTP/2": + """ + From the h2 spec: + + To ensure that the HTTP/1.1 request line can be reproduced accurately, this pseudo-header field MUST be + omitted when translating from an HTTP/1.1 request that has a request target in origin or asterisk form + (see [RFC7230], Section 5.3). Clients that generate HTTP/2 requests directly SHOULD use the :authority + pseudo-header field instead of the Host header field. An intermediary that converts an HTTP/2 request to + HTTP/1.1 MUST create a Host header field if one is not present in a request by copying the value of the + :authority pseudo-header field. + """ + if headers[3][0].lower() == b"host": + headers[3] = (b":authority", headers[3][1]) + else: + headers = normalize_h1_headers(headers, True) + + + + + self.h2_conn.send_headers( + event.stream_id, + headers, + ) + yield SendData(self.conn, self.h2_conn.data_to_send()) + else: + yield from super()._handle_event(event) + + def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]: + if isinstance(event, h2.events.ResponseReceived): + headers = net_http.Headers([(k, v) for k, v in event.headers]) + status_code = headers.pop(":status") + response = http.HTTPResponse( + b"HTTP/2", + status_code, + b"", + headers, + None, + ) + yield ReceiveHttp(ResponseHeaders(event.stream_id, response)) + else: + return (yield from super().handle_h2_event(event)) + + +def parse_h2_request_headers( + h2_headers: Iterable[Tuple[bytes, bytes]] +) -> Tuple[bytes, bytes, Optional[bytes], Optional[int], bytes, net_http.Headers]: + """Split HTTP/2 pseudo-headers from the actual headers and parse them.""" + pseudo_headers: Dict[bytes, bytes] = {} + i = 0 + for i, (header, value) in enumerate(h2_headers): + if header.startswith(b":"): + if header in pseudo_headers: + raise ValueError(f"Duplicate HTTP/2 pseudo headers: {header}") + pseudo_headers[header] = value + else: + # Pseudo-headers must be at the start, we are done here. + break + + headers = net_http.Headers(h2_headers[i:]) + + try: + method: bytes = pseudo_headers.pop(b":method") + scheme: bytes = pseudo_headers.pop(b":scheme") # this raises for HTTP/2 CONNECT requests + path: bytes = pseudo_headers.pop(b":path") + authority: bytes = pseudo_headers.pop(b":authority", None) + except KeyError as e: + raise ValueError(f"Required pseudo header is missing: {e}") + + if pseudo_headers: + raise ValueError(f"Unknown pseudo headers: {pseudo_headers}") + + host = None + port = None + if authority is not None: + headers.insert(0, b"Host", authority) + host, _, portstr = authority.rpartition(b":") # partition from the right to support IPv6 addresses + if host == b"": + host = portstr + port = 443 if scheme == b'https' else 80 + else: + port = int(portstr) + + return method, scheme, host, port, path, headers __all__ = [ diff --git a/test/mitmproxy/proxy2/layers/http/test_http2.py b/test/mitmproxy/proxy2/layers/http/test_http2.py index f3a4ba620..50bc2ffc7 100644 --- a/test/mitmproxy/proxy2/layers/http/test_http2.py +++ b/test/mitmproxy/proxy2/layers/http/test_http2.py @@ -1,5 +1,6 @@ -from typing import Callable, List +from typing import Callable, List, Tuple +import hpack import hyperframe.frame import pytest @@ -12,12 +13,6 @@ from mitmproxy.proxy2.layers import http from test.mitmproxy.proxy2.layers.http.hyper_h2_test_helpers import FrameFactory from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply - -@pytest.fixture -def frame_factory() -> FrameFactory: - return FrameFactory() - - example_request_headers = ( (b':authority', b'example.com'), (b':path', b'/'), @@ -32,6 +27,9 @@ example_response_headers = ( def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]: + # swallow preamble + if data.startswith(b"PRI * HTTP/2.0"): + data = data[24:] frames = [] while data: f, length = hyperframe.frame.Frame.parse_frame_header(data[:9]) @@ -41,8 +39,10 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]: return frames -def start_h2(tctx: Context, frame_factory: FrameFactory) -> Playbook: + +def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]: tctx.client.alpn = b"h2" + frame_factory = FrameFactory() playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) assert ( @@ -51,7 +51,7 @@ def start_h2(tctx: Context, frame_factory: FrameFactory) -> Playbook: >> DataReceived(tctx.client, frame_factory.preamble()) >> DataReceived(tctx.client, frame_factory.build_settings_frame({}, ack=True).serialize()) ) - return playbook + return playbook, frame_factory def make_h2(open_connection: OpenConnection) -> None: @@ -59,31 +59,28 @@ def make_h2(open_connection: OpenConnection) -> None: @pytest.mark.parametrize("stream", [True, False]) -def test_http2_client_aborts(tctx, frame_factory, stream): +def test_http2_client_aborts(tctx, stream): """Test handling of the case where a client aborts during request transmission.""" server = Placeholder(Server) flow = Placeholder(HTTPFlow) - playbook = start_h2(tctx, frame_factory) + playbook, cff = start_h2_client(tctx) def enable_streaming(flow: HTTPFlow): flow.request.stream = True assert ( playbook - >> DataReceived(tctx.client, frame_factory.build_headers_frame(example_request_headers).serialize()) + >> DataReceived(tctx.client, cff.build_headers_frame(example_request_headers).serialize()) << http.HttpRequestHeadersHook(flow) ) if stream: - pytest.xfail("h2 client not implemented yet") assert ( playbook >> reply(side_effect=enable_streaming) << OpenConnection(server) - >> reply(None, side_effect=make_h2) - << SendData(server, b"POST / HTTP/1.1\r\n" - b"Host: example.com\r\n" - b"Content-Length: 6\r\n\r\n" - b"abc") + >> reply(None) + << SendData(server, b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n\r\n") ) else: assert playbook >> reply() @@ -100,6 +97,63 @@ def test_http2_client_aborts(tctx, frame_factory, stream): @pytest.mark.xfail -def test_no_normalization(): +def test_no_normalization(tctx): """Test that we don't normalize headers when we just pass them through.""" - raise NotImplementedError + + server = Placeholder(Server) + flow = Placeholder(HTTPFlow) + playbook, cff = start_h2_client(tctx) + + request_headers = example_request_headers + ( + (b"Should-Not-Be-Capitalized! ", b" :) "), + ) + response_headers = example_response_headers + ( + (b"Same", b"Here"), + ) + + initial = Placeholder(bytes) + assert ( + playbook + >> DataReceived(tctx.client, + cff.build_headers_frame(request_headers, flags=["END_STREAM"]).serialize()) + << http.HttpRequestHeadersHook(flow) + >> reply() + << http.HttpRequestHook(flow) + >> reply() + << OpenConnection(server) + >> reply(None, side_effect=make_h2) + << SendData(server, initial) + ) + frames = decode_frames(initial()) + assert [type(x) for x in frames] == [ + hyperframe.frame.SettingsFrame, + hyperframe.frame.HeadersFrame, + hyperframe.frame.DataFrame + ] + assert hpack.hpack.Decoder().decode(frames[1].data, True) == list(request_headers) + + sff = FrameFactory() + assert ( + playbook + << SendData(server, sff.build_headers_frame(request_headers, flags=["END_STREAM"]).serialize()) + >> DataReceived(server, sff.build_headers_frame(response_headers, flags=["END_STREAM"]).serialize()) + << http.HttpResponseHeadersHook(flow) + >> reply() + << http.HttpResponseHook(flow) + >> reply() + << SendData(tctx.client, cff.build_headers_frame(response_headers, flags=["END_STREAM"]).serialize()) + ) + assert flow().request.headers.fields == request_headers + assert flow().response.headers.fields == response_headers + + +def start_h2_server(playbook: Playbook) -> FrameFactory: + frame_factory = FrameFactory() + server = Placeholder(Server) + assert ( + playbook + >> reply(None, side_effect=make_h2) + << SendData(server, Placeholder()) + ) + playbook >> DataReceived(server, frame_factory.build_settings_frame({}, ack=True)) + return frame_factory