[sans-io] h2 client (wip)

This commit is contained in:
Maximilian Hils 2020-06-30 22:47:43 +02:00
parent 0646a4f1ba
commit ffa5a69ebf
3 changed files with 300 additions and 104 deletions

View File

@ -570,7 +570,7 @@ class HttpClient(layer.Layer):
err = yield commands.OpenConnection(self.context.server) err = yield commands.OpenConnection(self.context.server)
if not err: if not err:
if self.context.server.alpn == b"h2": if self.context.server.alpn == b"h2":
raise NotImplementedError child_layer = Http2Client(self.context)
else: else:
child_layer = Http1Client(self.context) child_layer = Http1Client(self.context)
self._handle_event = child_layer.handle_event self._handle_event = child_layer.handle_event

View File

@ -1,5 +1,4 @@
import time from typing import ClassVar, Dict, Iterable, List, Optional, Tuple, Type, Union
from typing import ClassVar
import h2.connection import h2.connection
import h2.config import h2.config
@ -8,10 +7,10 @@ import h2.exceptions
import h2.settings import h2.settings
import h2.errors import h2.errors
import h2.utilities import h2.utilities
from hyperframe.frame import SettingsFrame
from mitmproxy import http from mitmproxy import http
from mitmproxy.net import http as net_http from mitmproxy.net import http as net_http
from mitmproxy.net.http import http2
from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._base import HttpConnection, HttpEvent, ReceiveHttp from ._base import HttpConnection, HttpEvent, ReceiveHttp
@ -24,27 +23,26 @@ from ...layer import CommandGenerator
class Http2Connection(HttpConnection): class Http2Connection(HttpConnection):
h2_conf: ClassVar[h2.config.H2Configuration] h2_conf: ClassVar[h2.config.H2Configuration]
h2_conn: BufferedH2Connection h2_conf_defaults = dict(
def __init__(self, context: Context, conn: Connection):
super().__init__(context, conn)
self.h2_conn = BufferedH2Connection(self.h2_conf)
class Http2Server(Http2Connection):
# noinspection PyTypeChecker
h2_conf = h2.config.H2Configuration(
client_side=False,
header_encoding=False, header_encoding=False,
validate_outbound_headers=False, validate_outbound_headers=False,
validate_inbound_headers=False, validate_inbound_headers=False,
normalize_inbound_headers=False, normalize_inbound_headers=False,
normalize_outbound_headers=False, normalize_outbound_headers=False,
logger=H2ConnectionLogger("server") # type: ignore logger=H2ConnectionLogger("server")
) )
h2_conn: BufferedH2Connection
def __init__(self, context: Context): ReceiveProtocolError: Type[Union[RequestProtocolError, ResponseProtocolError]]
super().__init__(context, context.client) SendProtocolError: Type[Union[RequestProtocolError, ResponseProtocolError]]
ReceiveData: Type[Union[RequestData, ResponseData]]
SendData: Type[Union[RequestData, ResponseData]]
ReceiveEndOfMessage: Type[Union[RequestEndOfMessage, ResponseEndOfMessage]]
SendEndOfMessage: Type[Union[RequestEndOfMessage, ResponseEndOfMessage]]
def __init__(self, context: Context, conn: Connection):
super().__init__(context, conn)
self.h2_conn = BufferedH2Connection(self.h2_conf)
def _handle_event(self, event: Event) -> CommandGenerator[None]: def _handle_event(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, Start): if isinstance(event, Start):
@ -52,30 +50,11 @@ class Http2Server(Http2Connection):
yield SendData(self.conn, self.h2_conn.data_to_send()) yield SendData(self.conn, self.h2_conn.data_to_send())
elif isinstance(event, HttpEvent): elif isinstance(event, HttpEvent):
if isinstance(event, ResponseHeaders): if isinstance(event, self.SendData):
headers = (
(b":status", b"%d" % event.response.status_code),
*event.response.headers.fields
)
if event.response.data.http_version != b"HTTP/2":
# HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length),
# which isn't valid HTTP/2. As such we normalize.
headers = h2.utilities.normalize_outbound_headers(
headers,
h2.utilities.HeaderValidationFlags(False, False, True, False)
)
# make sure that this is not just an iterator but an iterable,
# otherwise hyper-h2 will silently drop headers.
headers = list(headers)
self.h2_conn.send_headers(
event.stream_id,
headers,
)
elif isinstance(event, ResponseData):
self.h2_conn.send_data(event.stream_id, event.data) self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, ResponseEndOfMessage): elif isinstance(event, self.SendEndOfMessage):
self.h2_conn.send_data(event.stream_id, b"", end_stream=True) self.h2_conn.send_data(event.stream_id, b"", end_stream=True)
elif isinstance(event, ResponseProtocolError): elif isinstance(event, self.SendProtocolError):
self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR) self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR)
else: else:
raise NotImplementedError(f"Unknown HTTP event: {event}") raise NotImplementedError(f"Unknown HTTP event: {event}")
@ -88,62 +67,225 @@ class Http2Server(Http2Connection):
events = [e] events = [e]
for h2_event in events: for h2_event in events:
if isinstance(h2_event, h2.events.RequestReceived): if (yield from self.handle_h2_event(h2_event)):
headers = net_http.Headers([(k, v) for k, v in h2_event.headers]) return
first_line_format, method, scheme, host, port, path = http2.parse_headers(headers)
headers["Host"] = headers.pop(":authority") # FIXME: temporary workaround data_to_send = self.h2_conn.data_to_send()
if data_to_send:
yield SendData(self.conn, data_to_send)
elif isinstance(event, ConnectionClosed):
yield from self._unexpected_close("peer closed connection")
else:
raise NotImplementedError(f"Unexpected event: {event!r}")
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
"""returns true if further processing should be stopped."""
if isinstance(event, h2.events.DataReceived):
# noinspection PyArgumentList
yield ReceiveHttp(self.ReceiveData(event.stream_id, event.data))
self.h2_conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
elif isinstance(event, h2.events.StreamEnded):
# noinspection PyArgumentList
yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id))
elif isinstance(event, h2.exceptions.ProtocolError):
yield from self._unexpected_close(f"HTTP/2 protocol error: {event}")
return True
elif isinstance(event, h2.events.ConnectionTerminated):
yield from self._unexpected_close(f"HTTP/2 connection closed: {event!r}")
return True
elif isinstance(event, h2.events.StreamReset):
# noinspection PyArgumentList
yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, "Stream reset"))
elif isinstance(event, h2.events.RemoteSettingsChanged):
pass
elif isinstance(event, h2.events.SettingsAcknowledged):
pass
else:
raise NotImplementedError(f"Unknown event: {event!r}")
def _unexpected_close(self, err: str) -> CommandGenerator[None]:
yield CloseConnection(self.conn)
for stream_id, stream in self.h2_conn.streams.items():
if stream.open:
# noinspection PyArgumentList
yield ReceiveHttp(self.ReceiveProtocolError(stream_id, err))
def normalize_h1_headers(headers: List[Tuple[bytes, bytes]], is_client: bool) -> List[Tuple[bytes, bytes]]:
# HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length),
# which isn't valid HTTP/2. As such we normalize.
headers = h2.utilities.normalize_outbound_headers(
headers,
h2.utilities.HeaderValidationFlags(is_client, False, not is_client, False)
)
# make sure that this is not just an iterator but an iterable,
# otherwise hyper-h2 will silently drop headers.
headers = list(headers)
return headers
class Http2Server(Http2Connection):
h2_conf = h2.config.H2Configuration(
client_side=False,
**Http2Connection.h2_conf_defaults
)
ReceiveProtocolError = RequestProtocolError
SendProtocolError = ResponseProtocolError
ReceiveData = RequestData
SendData = ResponseData
ReceiveEndOfMessage = RequestEndOfMessage
SendEndOfMessage = ResponseEndOfMessage
def __init__(self, context: Context):
super().__init__(context, context.client)
def _handle_event(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, ResponseHeaders):
headers = [
(b":status", b"%d" % event.response.status_code),
*event.response.headers.fields
]
if event.response.http_version != b"HTTP/2":
headers = normalize_h1_headers(headers, False)
self.h2_conn.send_headers(
event.stream_id,
headers,
)
yield SendData(self.conn, self.h2_conn.data_to_send())
else:
yield from super()._handle_event(event)
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
if isinstance(event, h2.events.RequestReceived):
method, scheme, host, port, path, headers = parse_h2_request_headers(event.headers)
request = http.HTTPRequest( request = http.HTTPRequest(
first_line_format, "relative",
method, method,
scheme, scheme,
host, host,
port, port,
path, path,
b"HTTP/1.1", # FIXME: Figure out how to smooth h2 <-> h1. b"HTTP/2",
headers, headers,
None, None,
timestamp_start=time.time(),
) )
yield ReceiveHttp(RequestHeaders(h2_event.stream_id, request)) yield ReceiveHttp(RequestHeaders(event.stream_id, request))
elif isinstance(h2_event, h2.events.DataReceived):
yield ReceiveHttp(RequestData(h2_event.stream_id, h2_event.data))
self.h2_conn.acknowledge_received_data(len(h2_event.data), h2_event.stream_id)
elif isinstance(h2_event, h2.events.StreamEnded):
yield ReceiveHttp(RequestEndOfMessage(h2_event.stream_id))
elif isinstance(h2_event, h2.exceptions.ProtocolError):
yield CloseConnection(self.conn)
yield from self._notify_close(f"HTTP/2 protocol error: {h2_event}")
return
elif isinstance(h2_event, h2.events.ConnectionTerminated):
yield CloseConnection(self.conn)
yield from self._notify_close(f"HTTP/2 connection closed: {h2_event!r}")
return
elif isinstance(h2_event, h2.events.StreamReset):
yield ReceiveHttp(RequestProtocolError(h2_event.stream_id, "EOF"))
elif isinstance(h2_event, h2.events.RemoteSettingsChanged):
pass
elif isinstance(h2_event, h2.events.SettingsAcknowledged):
pass
else: else:
raise NotImplementedError(f"Unknown event: {h2_event!r}") return (yield from super().handle_h2_event(event))
data_to_send = self.h2_conn.data_to_send()
if data_to_send: class Http2Client(Http2Connection):
yield SendData(self.conn, data_to_send) h2_conf = h2.config.H2Configuration(
elif isinstance(event, ConnectionClosed): client_side=True,
yield CloseConnection(self.conn) **Http2Connection.h2_conf_defaults
yield from self._notify_close("peer closed connection") )
ReceiveProtocolError = ResponseProtocolError
SendProtocolError = RequestProtocolError
ReceiveData = ResponseData
SendData = RequestData
ReceiveEndOfMessage = ResponseEndOfMessage
SendEndOfMessage = RequestEndOfMessage
def __init__(self, context: Context):
super().__init__(context, context.server)
# Disable HTTP/2 push for now to keep things simple.
self.h2_conn.update_settings({SettingsFrame.ENABLE_PUSH: 0})
def _handle_event(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, RequestHeaders):
headers = [
(b':method', event.request.method),
(b':scheme', event.request.scheme),
(b':path', event.request.path),
*event.request.headers.fields
]
if event.request.http_version == b"HTTP/2":
"""
From the h2 spec:
To ensure that the HTTP/1.1 request line can be reproduced accurately, this pseudo-header field MUST be
omitted when translating from an HTTP/1.1 request that has a request target in origin or asterisk form
(see [RFC7230], Section 5.3). Clients that generate HTTP/2 requests directly SHOULD use the :authority
pseudo-header field instead of the Host header field. An intermediary that converts an HTTP/2 request to
HTTP/1.1 MUST create a Host header field if one is not present in a request by copying the value of the
:authority pseudo-header field.
"""
if headers[3][0].lower() == b"host":
headers[3] = (b":authority", headers[3][1])
else: else:
raise NotImplementedError(f"Unexpected event: {event!r}") headers = normalize_h1_headers(headers, True)
def _notify_close(self, err: str) -> CommandGenerator[None]:
for stream_id, stream in self.h2_conn.streams.items():
if stream.open:
yield ReceiveHttp(RequestProtocolError(stream_id, err))
class Http2Client:
pass # TODO
self.h2_conn.send_headers(
event.stream_id,
headers,
)
yield SendData(self.conn, self.h2_conn.data_to_send())
else:
yield from super()._handle_event(event)
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
if isinstance(event, h2.events.ResponseReceived):
headers = net_http.Headers([(k, v) for k, v in event.headers])
status_code = headers.pop(":status")
response = http.HTTPResponse(
b"HTTP/2",
status_code,
b"",
headers,
None,
)
yield ReceiveHttp(ResponseHeaders(event.stream_id, response))
else:
return (yield from super().handle_h2_event(event))
def parse_h2_request_headers(
h2_headers: Iterable[Tuple[bytes, bytes]]
) -> Tuple[bytes, bytes, Optional[bytes], Optional[int], bytes, net_http.Headers]:
"""Split HTTP/2 pseudo-headers from the actual headers and parse them."""
pseudo_headers: Dict[bytes, bytes] = {}
i = 0
for i, (header, value) in enumerate(h2_headers):
if header.startswith(b":"):
if header in pseudo_headers:
raise ValueError(f"Duplicate HTTP/2 pseudo headers: {header}")
pseudo_headers[header] = value
else:
# Pseudo-headers must be at the start, we are done here.
break
headers = net_http.Headers(h2_headers[i:])
try:
method: bytes = pseudo_headers.pop(b":method")
scheme: bytes = pseudo_headers.pop(b":scheme") # this raises for HTTP/2 CONNECT requests
path: bytes = pseudo_headers.pop(b":path")
authority: bytes = pseudo_headers.pop(b":authority", None)
except KeyError as e:
raise ValueError(f"Required pseudo header is missing: {e}")
if pseudo_headers:
raise ValueError(f"Unknown pseudo headers: {pseudo_headers}")
host = None
port = None
if authority is not None:
headers.insert(0, b"Host", authority)
host, _, portstr = authority.rpartition(b":") # partition from the right to support IPv6 addresses
if host == b"":
host = portstr
port = 443 if scheme == b'https' else 80
else:
port = int(portstr)
return method, scheme, host, port, path, headers
__all__ = [ __all__ = [

View File

@ -1,5 +1,6 @@
from typing import Callable, List from typing import Callable, List, Tuple
import hpack
import hyperframe.frame import hyperframe.frame
import pytest import pytest
@ -12,12 +13,6 @@ from mitmproxy.proxy2.layers import http
from test.mitmproxy.proxy2.layers.http.hyper_h2_test_helpers import FrameFactory from test.mitmproxy.proxy2.layers.http.hyper_h2_test_helpers import FrameFactory
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply
@pytest.fixture
def frame_factory() -> FrameFactory:
return FrameFactory()
example_request_headers = ( example_request_headers = (
(b':authority', b'example.com'), (b':authority', b'example.com'),
(b':path', b'/'), (b':path', b'/'),
@ -32,6 +27,9 @@ example_response_headers = (
def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]: def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
# swallow preamble
if data.startswith(b"PRI * HTTP/2.0"):
data = data[24:]
frames = [] frames = []
while data: while data:
f, length = hyperframe.frame.Frame.parse_frame_header(data[:9]) f, length = hyperframe.frame.Frame.parse_frame_header(data[:9])
@ -41,8 +39,10 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
return frames return frames
def start_h2(tctx: Context, frame_factory: FrameFactory) -> Playbook:
def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
tctx.client.alpn = b"h2" tctx.client.alpn = b"h2"
frame_factory = FrameFactory()
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
assert ( assert (
@ -51,7 +51,7 @@ def start_h2(tctx: Context, frame_factory: FrameFactory) -> Playbook:
>> DataReceived(tctx.client, frame_factory.preamble()) >> DataReceived(tctx.client, frame_factory.preamble())
>> DataReceived(tctx.client, frame_factory.build_settings_frame({}, ack=True).serialize()) >> DataReceived(tctx.client, frame_factory.build_settings_frame({}, ack=True).serialize())
) )
return playbook return playbook, frame_factory
def make_h2(open_connection: OpenConnection) -> None: def make_h2(open_connection: OpenConnection) -> None:
@ -59,31 +59,28 @@ def make_h2(open_connection: OpenConnection) -> None:
@pytest.mark.parametrize("stream", [True, False]) @pytest.mark.parametrize("stream", [True, False])
def test_http2_client_aborts(tctx, frame_factory, stream): def test_http2_client_aborts(tctx, stream):
"""Test handling of the case where a client aborts during request transmission.""" """Test handling of the case where a client aborts during request transmission."""
server = Placeholder(Server) server = Placeholder(Server)
flow = Placeholder(HTTPFlow) flow = Placeholder(HTTPFlow)
playbook = start_h2(tctx, frame_factory) playbook, cff = start_h2_client(tctx)
def enable_streaming(flow: HTTPFlow): def enable_streaming(flow: HTTPFlow):
flow.request.stream = True flow.request.stream = True
assert ( assert (
playbook playbook
>> DataReceived(tctx.client, frame_factory.build_headers_frame(example_request_headers).serialize()) >> DataReceived(tctx.client, cff.build_headers_frame(example_request_headers).serialize())
<< http.HttpRequestHeadersHook(flow) << http.HttpRequestHeadersHook(flow)
) )
if stream: if stream:
pytest.xfail("h2 client not implemented yet")
assert ( assert (
playbook playbook
>> reply(side_effect=enable_streaming) >> reply(side_effect=enable_streaming)
<< OpenConnection(server) << OpenConnection(server)
>> reply(None, side_effect=make_h2) >> reply(None)
<< SendData(server, b"POST / HTTP/1.1\r\n" << SendData(server, b"GET / HTTP/1.1\r\n"
b"Host: example.com\r\n" b"Host: example.com\r\n\r\n")
b"Content-Length: 6\r\n\r\n"
b"abc")
) )
else: else:
assert playbook >> reply() assert playbook >> reply()
@ -100,6 +97,63 @@ def test_http2_client_aborts(tctx, frame_factory, stream):
@pytest.mark.xfail @pytest.mark.xfail
def test_no_normalization(): def test_no_normalization(tctx):
"""Test that we don't normalize headers when we just pass them through.""" """Test that we don't normalize headers when we just pass them through."""
raise NotImplementedError
server = Placeholder(Server)
flow = Placeholder(HTTPFlow)
playbook, cff = start_h2_client(tctx)
request_headers = example_request_headers + (
(b"Should-Not-Be-Capitalized! ", b" :) "),
)
response_headers = example_response_headers + (
(b"Same", b"Here"),
)
initial = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(request_headers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, initial)
)
frames = decode_frames(initial())
assert [type(x) for x in frames] == [
hyperframe.frame.SettingsFrame,
hyperframe.frame.HeadersFrame,
hyperframe.frame.DataFrame
]
assert hpack.hpack.Decoder().decode(frames[1].data, True) == list(request_headers)
sff = FrameFactory()
assert (
playbook
<< SendData(server, sff.build_headers_frame(request_headers, flags=["END_STREAM"]).serialize())
>> DataReceived(server, sff.build_headers_frame(response_headers, flags=["END_STREAM"]).serialize())
<< http.HttpResponseHeadersHook(flow)
>> reply()
<< http.HttpResponseHook(flow)
>> reply()
<< SendData(tctx.client, cff.build_headers_frame(response_headers, flags=["END_STREAM"]).serialize())
)
assert flow().request.headers.fields == request_headers
assert flow().response.headers.fields == response_headers
def start_h2_server(playbook: Playbook) -> FrameFactory:
frame_factory = FrameFactory()
server = Placeholder(Server)
assert (
playbook
>> reply(None, side_effect=make_h2)
<< SendData(server, Placeholder())
)
playbook >> DataReceived(server, frame_factory.build_settings_frame({}, ack=True))
return frame_factory