diff --git a/mitmproxy/proxy2/layers/http/_http2.py b/mitmproxy/proxy2/layers/http/_http2.py index 800cac7fb..92bd81e1b 100644 --- a/mitmproxy/proxy2/layers/http/_http2.py +++ b/mitmproxy/proxy2/layers/http/_http2.py @@ -1,4 +1,5 @@ import time +from enum import Enum from typing import ClassVar, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import h2.connection @@ -23,6 +24,11 @@ from ...events import ConnectionClosed, DataReceived, Event, Start from ...layer import CommandGenerator +class StreamState(Enum): + EXPECTING_HEADERS = 1 + HEADERS_RECEIVED = 2 + + class Http2Connection(HttpConnection): h2_conf: ClassVar[h2.config.H2Configuration] h2_conf_defaults = dict( @@ -34,7 +40,7 @@ class Http2Connection(HttpConnection): # logger=H2ConnectionLogger("server") ) h2_conn: BufferedH2Connection - active_stream_ids: Set[int] + streams: Dict[int, StreamState] """keep track of all active stream ids to send protocol errors on teardown""" ReceiveProtocolError: Type[Union[RequestProtocolError, ResponseProtocolError]] @@ -47,7 +53,7 @@ class Http2Connection(HttpConnection): def __init__(self, context: Context, conn: Connection): super().__init__(context, conn) self.h2_conn = BufferedH2Connection(self.h2_conf) - self.active_stream_ids = set() + self.streams = {} def _handle_event(self, event: Event) -> CommandGenerator[None]: if isinstance(event, Start): @@ -67,7 +73,13 @@ class Http2Connection(HttpConnection): elif isinstance(event, DataReceived): try: - events = self.h2_conn.receive_data(event.data) + try: + events = self.h2_conn.receive_data(event.data) + except ValueError as e: # pragma: no cover + # this should never raise a ValueError, but we triggered one while fuzzing: + # https://github.com/python-hyper/hyper-h2/issues/1231 + # this stays here as defense-in-depth. + raise h2.exceptions.ProtocolError(f"uncaught hyper-h2 error: {e}") from e except h2.exceptions.ProtocolError as e: events = [e] @@ -80,38 +92,47 @@ class Http2Connection(HttpConnection): yield SendData(self.conn, data_to_send) elif isinstance(event, ConnectionClosed): - yield from self._unexpected_close("peer closed connection") + yield from self.close_connection("peer closed connection") else: raise AssertionError(f"Unexpected event: {event!r}") + # noinspection PyArgumentList def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]: """returns true if further processing should be stopped.""" if isinstance(event, h2.events.DataReceived): - if event.stream_id in self.active_stream_ids: - # noinspection PyArgumentList + state = self.streams.get(event.stream_id, None) + if state is StreamState.HEADERS_RECEIVED: yield ReceiveHttp(self.ReceiveData(event.stream_id, event.data)) + elif state is StreamState.EXPECTING_HEADERS: + yield from self.protocol_error(f"Received HTTP/2 data frame, expected headers.") + return True self.h2_conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) elif isinstance(event, h2.events.StreamEnded): - if event.stream_id in self.active_stream_ids: - # noinspection PyArgumentList + state = self.streams.get(event.stream_id, None) + if state is StreamState.HEADERS_RECEIVED: yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id)) - self.active_stream_ids.remove(event.stream_id) + elif state is StreamState.EXPECTING_HEADERS: + raise AssertionError("unreachable") + self.streams.pop(event.stream_id, None) + elif isinstance(event, h2.events.StreamReset): + if event.stream_id in self.streams: + yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, "Stream reset")) elif isinstance(event, h2.exceptions.ProtocolError): - yield from self._unexpected_close(f"HTTP/2 protocol error: {event}") + yield from self.protocol_error(f"HTTP/2 protocol error: {event}") return True elif isinstance(event, h2.events.ConnectionTerminated): - yield from self._unexpected_close(f"HTTP/2 connection closed: {event!r}") + yield from self.close_connection(f"HTTP/2 connection closed: {event!r}") return True - elif isinstance(event, h2.events.StreamReset): - if event.stream_id in self.active_stream_ids: - # noinspection PyArgumentList - yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, "Stream reset")) elif isinstance(event, h2.events.RemoteSettingsChanged): pass elif isinstance(event, h2.events.SettingsAcknowledged): pass elif isinstance(event, h2.events.PriorityUpdated): pass + elif isinstance(event, h2.events.TrailersReceived): + yield Log("Received HTTP/2 trailers, which are currently unimplemented and silently discarded", "error") + elif isinstance(event, h2.events.PushedStreamReceived): + yield Log("Received HTTP/2 push promise, even though we signalled no support.", "error") elif isinstance(event, h2.events.UnknownFrameReceived): # https://http2.github.io/http2-spec/#rfc.section.4.1 # Implementations MUST ignore and discard any frame that has a type that is unknown. @@ -119,11 +140,21 @@ class Http2Connection(HttpConnection): else: raise AssertionError(f"Unexpected event: {event!r}") - def _unexpected_close(self, err: str) -> CommandGenerator[None]: + def protocol_error( + self, + message: str, + error_code: int = h2.errors.ErrorCodes.PROTOCOL_ERROR, + ) -> CommandGenerator[None]: + yield Log(f"{human.format_address(self.conn.peername)}: {message}") + self.h2_conn.close_connection(error_code, message.encode()) + yield SendData(self.conn, self.h2_conn.data_to_send()) + yield from self.close_connection(message) + + def close_connection(self, msg: str) -> CommandGenerator[None]: yield CloseConnection(self.conn) - for stream_id in self.active_stream_ids: + for stream_id in self.streams: # noinspection PyArgumentList - yield ReceiveHttp(self.ReceiveProtocolError(stream_id, err)) + yield ReceiveHttp(self.ReceiveProtocolError(stream_id, msg)) def normalize_h1_headers(headers: List[Tuple[bytes, bytes]], is_client: bool) -> List[Tuple[bytes, bytes]]: @@ -177,10 +208,8 @@ class Http2Server(Http2Connection): try: host, port, method, scheme, authority, path, headers = parse_h2_request_headers(event.headers) except ValueError as e: - yield Log(f"{human.format_address(self.conn.peername)}: {e}") - self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR) - yield SendData(self.conn, self.h2_conn.data_to_send()) - return + yield from self.protocol_error(f"Invalid HTTP/2 request headers: {e}") + return True request = http.HTTPRequest( host=host, port=port, @@ -195,7 +224,7 @@ class Http2Server(Http2Connection): timestamp_start=time.time(), timestamp_end=None, ) - self.active_stream_ids.add(event.stream_id) + self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED yield ReceiveHttp(RequestHeaders(event.stream_id, request)) else: return (yield from super().handle_h2_event(event)) @@ -204,7 +233,7 @@ class Http2Server(Http2Connection): class Http2Client(Http2Connection): h2_conf = h2.config.H2Configuration( **Http2Connection.h2_conf_defaults, - client_side = True, + client_side=True, ) ReceiveProtocolError = ResponseProtocolError @@ -237,14 +266,23 @@ class Http2Client(Http2Connection): event.stream_id, headers, ) - self.active_stream_ids.add(event.stream_id) + self.streams[event.stream_id] = StreamState.EXPECTING_HEADERS yield SendData(self.conn, self.h2_conn.data_to_send()) else: yield from super()._handle_event(event) def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]: if isinstance(event, h2.events.ResponseReceived): - status_code, headers = parse_h2_response_headers(event.headers) + if self.streams.get(event.stream_id, None) is not StreamState.EXPECTING_HEADERS: + yield from self.protocol_error(f"Received unexpected HTTP/2 response.") + return True + + try: + status_code, headers = parse_h2_response_headers(event.headers) + except ValueError as e: + yield from self.protocol_error(f"Invalid HTTP/2 response headers: {e}") + return True + response = http.HTTPResponse( http_version=b"HTTP/2.0", status_code=status_code, @@ -255,7 +293,11 @@ class Http2Client(Http2Connection): timestamp_start=time.time(), timestamp_end=None, ) + self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED yield ReceiveHttp(ResponseHeaders(event.stream_id, response)) + elif isinstance(event, h2.events.RequestReceived): + yield from self.protocol_error(f"HTTP/2 protocol error: received request from server") + return True else: return (yield from super().handle_h2_event(event)) @@ -277,6 +319,7 @@ def split_pseudo_headers(h2_headers: Iterable[Tuple[bytes, bytes]]) -> Tuple[Dic return pseudo_headers, headers + def parse_h2_request_headers( h2_headers: Iterable[Tuple[bytes, bytes]] ) -> Tuple[str, int, bytes, bytes, bytes, bytes, net_http.Headers]: diff --git a/test/mitmproxy/proxy2/layers/http/test_http_fuzz.py b/test/mitmproxy/proxy2/layers/http/test_http_fuzz.py index 2da6bb2bd..ec6b2390c 100644 --- a/test/mitmproxy/proxy2/layers/http/test_http_fuzz.py +++ b/test/mitmproxy/proxy2/layers/http/test_http_fuzz.py @@ -14,6 +14,7 @@ from mitmproxy.proxy2.commands import OpenConnection, SendData from mitmproxy.proxy2.events import DataReceived, Start from mitmproxy.proxy2.layers import http from test.mitmproxy.proxy2.layers.http.hyper_h2_test_helpers import FrameFactory +from test.mitmproxy.proxy2.layers.http.test_http2 import make_h2 from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply settings.register_profile("fast", max_examples=10) @@ -50,7 +51,10 @@ bodies = sampled_from([ ]) -def mutate(draw, data: bytes) -> bytes: +@composite +def mutations(draw, elements): + data = draw(elements) + cut_start = draw(integers(0, len(data))) cut_end = draw(integers(cut_start, len(data))) data = data[:cut_start] + data[cut_end:] @@ -60,40 +64,42 @@ def mutate(draw, data: bytes) -> bytes: return data[:replace_start] + draw(binary()) + data[replace_end:] -def split(draw, data: bytes) -> Iterable[bytes]: +@composite +def chunks(draw, elements): + data = draw(elements) + + chunks = [] a, b = sorted([ draw(integers(0, len(data))), draw(integers(0, len(data))) ]) if a > 0: - yield data[:a] + chunks.append(data[:a]) if a != b: - yield data[a:b] + chunks.append(data[a:b]) if b < len(data): - yield data[b:] + chunks.append(data[b:]) + + return chunks @composite -def fuzz_h1_request(draw): +def h1_requests(draw): request = draw(request_lines) + b"\r\n" request += b"\r\n".join(draw(headers)) request += b"\r\n\r\n" + draw(bodies) - request = mutate(draw, request) - request = list(split(draw, request)) return request @composite -def fuzz_h1_response(draw): +def h2_responses(draw): response = draw(response_lines) + b"\r\n" response += b"\r\n".join(draw(headers)) response += b"\r\n\r\n" + draw(bodies) - response = mutate(draw, response) - response = list(split(draw, response)) return response -@given(fuzz_h1_request()) +@given(chunks(mutations(h1_requests()))) def test_fuzz_h1_request(data): tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080)), opts) @@ -105,7 +111,7 @@ def test_fuzz_h1_request(data): pass -@given(fuzz_h1_response()) +@given(chunks(mutations(h2_responses()))) @example([b'0 OK\r\n\r\n', b'\r\n', b'5\r\n12345\r\n0\r\n\r\n']) def test_fuzz_h1_response(data): tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080)), opts) @@ -127,10 +133,45 @@ h2_flags = sets(sampled_from([ "END_STREAM", "END_HEADERS", ])) +h2_stream_ids = integers(0, 3) +h2_stream_ids_nonzero = integers(1, 3) -def _h2_request_parts(draw): +@composite +def h2_headers(draw): + required_headers = [ + [":path", '/'], + [":scheme", draw(sampled_from(['http', 'https']))], + [":method", draw(sampled_from(['GET', 'POST', 'CONNECT']))], + ] + optional_headers = [ + [":authority", draw(sampled_from(['example.com:443', 'example.com']))], + ["cookie", "foobaz"], + ["host", "example.com"], + ["content-length", "42"], + ] + headers = required_headers + draw(lists(sampled_from(optional_headers), max_size=3)) + + i = draw(integers(0, len(headers))) + p = int(draw(booleans())) + r = draw(text()) + if i > 0: + headers[i - 1][p - 1] = r + return headers + + +@composite +def h2_frames(draw): ff = FrameFactory() + headers1 = ff.build_headers_frame(headers=draw(h2_headers())) + headers1.flags.clear() + headers1.flags |= draw(h2_flags) + headers2 = ff.build_headers_frame(headers=draw(h2_headers()), + depends_on=draw(h2_stream_ids), + stream_weight=draw(integers(0, 255)), + exclusive=draw(booleans())) + headers2.flags.clear() + headers2.flags |= draw(h2_flags) settings = ff.build_settings_frame( settings=draw(dictionaries( keys=sampled_from(SettingCodes), @@ -138,39 +179,32 @@ def _h2_request_parts(draw): max_size=5, )), ack=draw(booleans()) ) - headers = ff.build_headers_frame( - headers=draw(permutations([ - (':authority', draw(sampled_from(['example.com', 'example.com:443', draw(text())]))), - (':path', draw(sampled_from(['/', draw(text())]))), - (':scheme', draw(sampled_from(['http', 'https', draw(text())]))), - (':method', draw(sampled_from(['GET', 'POST', 'CONNECT', draw(text())]))), - ('cookie', draw(text())), - ('host', draw(text())), - ('content-length', draw(text())) - ])) + continuation = ff.build_continuation_frame(header_block=ff.encoder.encode(draw(h2_headers())), flags=draw(h2_flags)) + goaway = ff.build_goaway_frame(draw(h2_stream_ids)) + push_promise = ff.build_push_promise_frame( + stream_id=draw(h2_stream_ids_nonzero), + promised_stream_id=draw(h2_stream_ids), + headers=draw(h2_headers()) ) - headers.flags.clear() - headers.flags |= draw(h2_flags) - data = ff.build_data_frame( + rst = ff.build_rst_stream_frame(draw(h2_stream_ids_nonzero)) + prio = ff.build_priority_frame( + stream_id=draw(h2_stream_ids_nonzero), + weight=draw(integers(0, 255)), + depends_on=draw(h2_stream_ids), + exclusive=draw(booleans()), + ) + data1 = ff.build_data_frame( draw(binary()), draw(h2_flags) ) - window_update = ff.build_window_update_frame(draw(sampled_from([1, 0, 2])), draw(integers(0, 2 ** 32 - 1))) + data2 = ff.build_data_frame( + draw(binary()), draw(h2_flags), stream_id=draw(h2_stream_ids_nonzero) + ) + window_update = ff.build_window_update_frame(draw(h2_stream_ids), draw(integers(0, 2 ** 32 - 1))) - return draw(lists(sampled_from([headers, settings, data, window_update]), min_size=1, max_size=4)) - - -@composite -def h2_request_parts(draw): - return _h2_request_parts(draw) - - -@composite -def h2_request_chunks(draw): - parts = _h2_request_parts(draw) - request = b"".join(x.serialize() for x in parts) - request = mutate(draw, request) - request = list(split(draw, request)) - return request + frames = draw(lists(sampled_from([ + headers1, headers2, settings, continuation, goaway, push_promise, rst, prio, data1, data2, window_update + ]), min_size=1, max_size=11)) + return b"".join(x.serialize() for x in frames) def h2_layer(opts): @@ -185,23 +219,59 @@ def h2_layer(opts): return tctx, layer -@given(h2_request_parts()) -def test_fuzz_h2_request(parts): +def _h2_request(chunks): tctx, layer = h2_layer(opts) - for part in parts: - for _ in layer.handle_event(DataReceived(tctx.client, part.serialize())): + for chunk in chunks: + for _ in layer.handle_event(DataReceived(tctx.client, chunk)): pass -@given(h2_request_chunks()) +@given(chunks(h2_frames())) +@example([b'\x00\x00\x00\x01\x05\x00\x00\x00\x01\x00\x00\x00\x01\x05\x00\x00\x00\x01']) @example([b'\x00\x00\x00\x01\x07\x00\x00\x00\x01A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x86\x82`\x80f\x80\\\x80']) @example([b'\x00\x00\x05\x02\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00']) @example([b'\x00\x00\x13\x01\x04\x00\x00\x00\x01A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x86\x82`\x80f\x80\\\x80']) @example([b'\x00\x00\x12\x01\x04\x00\x00\x00\x01\x84\x86\x82`\x80A\x88/\x91\xd3]\x05\\\x87\xa7\\\x81\x07']) @example([b'\x00\x00\x12\x01\x04\x00\x00\x00\x01\x84\x86\x82`\x80A\x88/\x91\xd3]\x05\\\x87\xa7\\\x81\x07']) -@example([b'\x00\x00\x14\x01\x04\x00\x00\x00\x01A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x86`\x80\x82f\x80\\\x81\x07']) -def test_fuzz_h2_request2(chunks): - tctx, layer = h2_layer(opts) +@example([b'\x00\x00\x14\x01\x04\x00\x00\x00\x01A\x88/\x91\xd3]\x05\\\x87\xa7\x84\x86`\x80\x82f\x80']) +@example([ + b'\x00\x00%\x01\x04\x00\x00\x00\x01A\x8b/\x91\xd3]\x05\\\x87\xa6\xe3M3\x84\x86\x82`\x85\x94\xe7\x8c~\xfff\x88/\x91\xd3]\x05\\\x87\xa7\\\x82h_\x00\x00\x07\x01\x05\x00\x00\x00\x01\xc1\x84\x86\x82\xc0\xbf\xbe']) +def test_fuzz_h2_request_chunks(chunks): + _h2_request(chunks) + + +@given(chunks(mutations(h2_frames()))) +def test_fuzz_h2_request_mutations(chunks): + _h2_request(chunks) + + +def _h2_response(chunks): + tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080)), opts) + playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) + server = Placeholder(context.Server) + assert ( + playbook + >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") + << OpenConnection(server) + >> reply(None, side_effect=make_h2) + << SendData(server, Placeholder()) + ) for chunk in chunks: - for _ in layer.handle_event(DataReceived(tctx.client, chunk)): + for _ in playbook.layer.handle_event(events.DataReceived(server(), chunk)): pass + + +@given(chunks(h2_frames())) +@example([b'\x00\x00\x03\x01\x04\x00\x00\x00\x01\x84\x86\x82', + b'\x00\x00\x07\x05\x04\x00\x00\x00\x01\x00\x00\x00\x00\x84\x86\x82']) +@example([b'\x00\x00\x00\x00\x00\x00\x00\x00\x01']) +@example([b'\x00\x00\x00\x01\x04\x00\x00\x00\x01']) +@example([b'\x00\x00\x07\x05\x04\x00\x00\x00\x01\x00\x00\x00\x02\x84\x86\x82']) +@example([b'\x00\x00\x06\x014\x00\x01\x00\x00\x00\x00\x01@\x80\x81c\x86\x82']) +def test_fuzz_h2_response_chunks(chunks): + _h2_response(chunks) + + +@given(chunks(mutations(h2_frames()))) +def test_fuzz_h2_response_mutations(chunks): + _h2_response(chunks)