[sans-io] h2 client (wip)

This commit is contained in:
Maximilian Hils 2020-06-30 22:47:43 +02:00
parent 0646a4f1ba
commit ffa5a69ebf
3 changed files with 300 additions and 104 deletions

View File

@ -570,7 +570,7 @@ class HttpClient(layer.Layer):
err = yield commands.OpenConnection(self.context.server)
if not err:
if self.context.server.alpn == b"h2":
raise NotImplementedError
child_layer = Http2Client(self.context)
else:
child_layer = Http1Client(self.context)
self._handle_event = child_layer.handle_event

View File

@ -1,5 +1,4 @@
import time
from typing import ClassVar
from typing import ClassVar, Dict, Iterable, List, Optional, Tuple, Type, Union
import h2.connection
import h2.config
@ -8,10 +7,10 @@ import h2.exceptions
import h2.settings
import h2.errors
import h2.utilities
from hyperframe.frame import SettingsFrame
from mitmproxy import http
from mitmproxy.net import http as net_http
from mitmproxy.net.http import http2
from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._base import HttpConnection, HttpEvent, ReceiveHttp
@ -24,27 +23,26 @@ from ...layer import CommandGenerator
class Http2Connection(HttpConnection):
h2_conf: ClassVar[h2.config.H2Configuration]
h2_conn: BufferedH2Connection
def __init__(self, context: Context, conn: Connection):
super().__init__(context, conn)
self.h2_conn = BufferedH2Connection(self.h2_conf)
class Http2Server(Http2Connection):
# noinspection PyTypeChecker
h2_conf = h2.config.H2Configuration(
client_side=False,
h2_conf_defaults = dict(
header_encoding=False,
validate_outbound_headers=False,
validate_inbound_headers=False,
normalize_inbound_headers=False,
normalize_outbound_headers=False,
logger=H2ConnectionLogger("server") # type: ignore
logger=H2ConnectionLogger("server")
)
h2_conn: BufferedH2Connection
def __init__(self, context: Context):
super().__init__(context, context.client)
ReceiveProtocolError: Type[Union[RequestProtocolError, ResponseProtocolError]]
SendProtocolError: Type[Union[RequestProtocolError, ResponseProtocolError]]
ReceiveData: Type[Union[RequestData, ResponseData]]
SendData: Type[Union[RequestData, ResponseData]]
ReceiveEndOfMessage: Type[Union[RequestEndOfMessage, ResponseEndOfMessage]]
SendEndOfMessage: Type[Union[RequestEndOfMessage, ResponseEndOfMessage]]
def __init__(self, context: Context, conn: Connection):
super().__init__(context, conn)
self.h2_conn = BufferedH2Connection(self.h2_conf)
def _handle_event(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, Start):
@ -52,30 +50,11 @@ class Http2Server(Http2Connection):
yield SendData(self.conn, self.h2_conn.data_to_send())
elif isinstance(event, HttpEvent):
if isinstance(event, ResponseHeaders):
headers = (
(b":status", b"%d" % event.response.status_code),
*event.response.headers.fields
)
if event.response.data.http_version != b"HTTP/2":
# HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length),
# which isn't valid HTTP/2. As such we normalize.
headers = h2.utilities.normalize_outbound_headers(
headers,
h2.utilities.HeaderValidationFlags(False, False, True, False)
)
# make sure that this is not just an iterator but an iterable,
# otherwise hyper-h2 will silently drop headers.
headers = list(headers)
self.h2_conn.send_headers(
event.stream_id,
headers,
)
elif isinstance(event, ResponseData):
if isinstance(event, self.SendData):
self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, ResponseEndOfMessage):
elif isinstance(event, self.SendEndOfMessage):
self.h2_conn.send_data(event.stream_id, b"", end_stream=True)
elif isinstance(event, ResponseProtocolError):
elif isinstance(event, self.SendProtocolError):
self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR)
else:
raise NotImplementedError(f"Unknown HTTP event: {event}")
@ -88,62 +67,225 @@ class Http2Server(Http2Connection):
events = [e]
for h2_event in events:
if isinstance(h2_event, h2.events.RequestReceived):
headers = net_http.Headers([(k, v) for k, v in h2_event.headers])
first_line_format, method, scheme, host, port, path = http2.parse_headers(headers)
headers["Host"] = headers.pop(":authority") # FIXME: temporary workaround
if (yield from self.handle_h2_event(h2_event)):
return
data_to_send = self.h2_conn.data_to_send()
if data_to_send:
yield SendData(self.conn, data_to_send)
elif isinstance(event, ConnectionClosed):
yield from self._unexpected_close("peer closed connection")
else:
raise NotImplementedError(f"Unexpected event: {event!r}")
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
"""returns true if further processing should be stopped."""
if isinstance(event, h2.events.DataReceived):
# noinspection PyArgumentList
yield ReceiveHttp(self.ReceiveData(event.stream_id, event.data))
self.h2_conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
elif isinstance(event, h2.events.StreamEnded):
# noinspection PyArgumentList
yield ReceiveHttp(self.ReceiveEndOfMessage(event.stream_id))
elif isinstance(event, h2.exceptions.ProtocolError):
yield from self._unexpected_close(f"HTTP/2 protocol error: {event}")
return True
elif isinstance(event, h2.events.ConnectionTerminated):
yield from self._unexpected_close(f"HTTP/2 connection closed: {event!r}")
return True
elif isinstance(event, h2.events.StreamReset):
# noinspection PyArgumentList
yield ReceiveHttp(self.ReceiveProtocolError(event.stream_id, "Stream reset"))
elif isinstance(event, h2.events.RemoteSettingsChanged):
pass
elif isinstance(event, h2.events.SettingsAcknowledged):
pass
else:
raise NotImplementedError(f"Unknown event: {event!r}")
def _unexpected_close(self, err: str) -> CommandGenerator[None]:
yield CloseConnection(self.conn)
for stream_id, stream in self.h2_conn.streams.items():
if stream.open:
# noinspection PyArgumentList
yield ReceiveHttp(self.ReceiveProtocolError(stream_id, err))
def normalize_h1_headers(headers: List[Tuple[bytes, bytes]], is_client: bool) -> List[Tuple[bytes, bytes]]:
# HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length),
# which isn't valid HTTP/2. As such we normalize.
headers = h2.utilities.normalize_outbound_headers(
headers,
h2.utilities.HeaderValidationFlags(is_client, False, not is_client, False)
)
# make sure that this is not just an iterator but an iterable,
# otherwise hyper-h2 will silently drop headers.
headers = list(headers)
return headers
class Http2Server(Http2Connection):
h2_conf = h2.config.H2Configuration(
client_side=False,
**Http2Connection.h2_conf_defaults
)
ReceiveProtocolError = RequestProtocolError
SendProtocolError = ResponseProtocolError
ReceiveData = RequestData
SendData = ResponseData
ReceiveEndOfMessage = RequestEndOfMessage
SendEndOfMessage = ResponseEndOfMessage
def __init__(self, context: Context):
super().__init__(context, context.client)
def _handle_event(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, ResponseHeaders):
headers = [
(b":status", b"%d" % event.response.status_code),
*event.response.headers.fields
]
if event.response.http_version != b"HTTP/2":
headers = normalize_h1_headers(headers, False)
self.h2_conn.send_headers(
event.stream_id,
headers,
)
yield SendData(self.conn, self.h2_conn.data_to_send())
else:
yield from super()._handle_event(event)
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
if isinstance(event, h2.events.RequestReceived):
method, scheme, host, port, path, headers = parse_h2_request_headers(event.headers)
request = http.HTTPRequest(
first_line_format,
"relative",
method,
scheme,
host,
port,
path,
b"HTTP/1.1", # FIXME: Figure out how to smooth h2 <-> h1.
b"HTTP/2",
headers,
None,
timestamp_start=time.time(),
)
yield ReceiveHttp(RequestHeaders(h2_event.stream_id, request))
elif isinstance(h2_event, h2.events.DataReceived):
yield ReceiveHttp(RequestData(h2_event.stream_id, h2_event.data))
self.h2_conn.acknowledge_received_data(len(h2_event.data), h2_event.stream_id)
elif isinstance(h2_event, h2.events.StreamEnded):
yield ReceiveHttp(RequestEndOfMessage(h2_event.stream_id))
elif isinstance(h2_event, h2.exceptions.ProtocolError):
yield CloseConnection(self.conn)
yield from self._notify_close(f"HTTP/2 protocol error: {h2_event}")
return
elif isinstance(h2_event, h2.events.ConnectionTerminated):
yield CloseConnection(self.conn)
yield from self._notify_close(f"HTTP/2 connection closed: {h2_event!r}")
return
elif isinstance(h2_event, h2.events.StreamReset):
yield ReceiveHttp(RequestProtocolError(h2_event.stream_id, "EOF"))
elif isinstance(h2_event, h2.events.RemoteSettingsChanged):
pass
elif isinstance(h2_event, h2.events.SettingsAcknowledged):
pass
yield ReceiveHttp(RequestHeaders(event.stream_id, request))
else:
raise NotImplementedError(f"Unknown event: {h2_event!r}")
return (yield from super().handle_h2_event(event))
data_to_send = self.h2_conn.data_to_send()
if data_to_send:
yield SendData(self.conn, data_to_send)
elif isinstance(event, ConnectionClosed):
yield CloseConnection(self.conn)
yield from self._notify_close("peer closed connection")
class Http2Client(Http2Connection):
h2_conf = h2.config.H2Configuration(
client_side=True,
**Http2Connection.h2_conf_defaults
)
ReceiveProtocolError = ResponseProtocolError
SendProtocolError = RequestProtocolError
ReceiveData = ResponseData
SendData = RequestData
ReceiveEndOfMessage = ResponseEndOfMessage
SendEndOfMessage = RequestEndOfMessage
def __init__(self, context: Context):
super().__init__(context, context.server)
# Disable HTTP/2 push for now to keep things simple.
self.h2_conn.update_settings({SettingsFrame.ENABLE_PUSH: 0})
def _handle_event(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, RequestHeaders):
headers = [
(b':method', event.request.method),
(b':scheme', event.request.scheme),
(b':path', event.request.path),
*event.request.headers.fields
]
if event.request.http_version == b"HTTP/2":
"""
From the h2 spec:
To ensure that the HTTP/1.1 request line can be reproduced accurately, this pseudo-header field MUST be
omitted when translating from an HTTP/1.1 request that has a request target in origin or asterisk form
(see [RFC7230], Section 5.3). Clients that generate HTTP/2 requests directly SHOULD use the :authority
pseudo-header field instead of the Host header field. An intermediary that converts an HTTP/2 request to
HTTP/1.1 MUST create a Host header field if one is not present in a request by copying the value of the
:authority pseudo-header field.
"""
if headers[3][0].lower() == b"host":
headers[3] = (b":authority", headers[3][1])
else:
raise NotImplementedError(f"Unexpected event: {event!r}")
def _notify_close(self, err: str) -> CommandGenerator[None]:
for stream_id, stream in self.h2_conn.streams.items():
if stream.open:
yield ReceiveHttp(RequestProtocolError(stream_id, err))
headers = normalize_h1_headers(headers, True)
class Http2Client:
pass # TODO
self.h2_conn.send_headers(
event.stream_id,
headers,
)
yield SendData(self.conn, self.h2_conn.data_to_send())
else:
yield from super()._handle_event(event)
def handle_h2_event(self, event: h2.events.Event) -> CommandGenerator[bool]:
if isinstance(event, h2.events.ResponseReceived):
headers = net_http.Headers([(k, v) for k, v in event.headers])
status_code = headers.pop(":status")
response = http.HTTPResponse(
b"HTTP/2",
status_code,
b"",
headers,
None,
)
yield ReceiveHttp(ResponseHeaders(event.stream_id, response))
else:
return (yield from super().handle_h2_event(event))
def parse_h2_request_headers(
h2_headers: Iterable[Tuple[bytes, bytes]]
) -> Tuple[bytes, bytes, Optional[bytes], Optional[int], bytes, net_http.Headers]:
"""Split HTTP/2 pseudo-headers from the actual headers and parse them."""
pseudo_headers: Dict[bytes, bytes] = {}
i = 0
for i, (header, value) in enumerate(h2_headers):
if header.startswith(b":"):
if header in pseudo_headers:
raise ValueError(f"Duplicate HTTP/2 pseudo headers: {header}")
pseudo_headers[header] = value
else:
# Pseudo-headers must be at the start, we are done here.
break
headers = net_http.Headers(h2_headers[i:])
try:
method: bytes = pseudo_headers.pop(b":method")
scheme: bytes = pseudo_headers.pop(b":scheme") # this raises for HTTP/2 CONNECT requests
path: bytes = pseudo_headers.pop(b":path")
authority: bytes = pseudo_headers.pop(b":authority", None)
except KeyError as e:
raise ValueError(f"Required pseudo header is missing: {e}")
if pseudo_headers:
raise ValueError(f"Unknown pseudo headers: {pseudo_headers}")
host = None
port = None
if authority is not None:
headers.insert(0, b"Host", authority)
host, _, portstr = authority.rpartition(b":") # partition from the right to support IPv6 addresses
if host == b"":
host = portstr
port = 443 if scheme == b'https' else 80
else:
port = int(portstr)
return method, scheme, host, port, path, headers
__all__ = [

View File

@ -1,5 +1,6 @@
from typing import Callable, List
from typing import Callable, List, Tuple
import hpack
import hyperframe.frame
import pytest
@ -12,12 +13,6 @@ from mitmproxy.proxy2.layers import http
from test.mitmproxy.proxy2.layers.http.hyper_h2_test_helpers import FrameFactory
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply
@pytest.fixture
def frame_factory() -> FrameFactory:
return FrameFactory()
example_request_headers = (
(b':authority', b'example.com'),
(b':path', b'/'),
@ -32,6 +27,9 @@ example_response_headers = (
def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
# swallow preamble
if data.startswith(b"PRI * HTTP/2.0"):
data = data[24:]
frames = []
while data:
f, length = hyperframe.frame.Frame.parse_frame_header(data[:9])
@ -41,8 +39,10 @@ def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
return frames
def start_h2(tctx: Context, frame_factory: FrameFactory) -> Playbook:
def start_h2_client(tctx: Context) -> Tuple[Playbook, FrameFactory]:
tctx.client.alpn = b"h2"
frame_factory = FrameFactory()
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
assert (
@ -51,7 +51,7 @@ def start_h2(tctx: Context, frame_factory: FrameFactory) -> Playbook:
>> DataReceived(tctx.client, frame_factory.preamble())
>> DataReceived(tctx.client, frame_factory.build_settings_frame({}, ack=True).serialize())
)
return playbook
return playbook, frame_factory
def make_h2(open_connection: OpenConnection) -> None:
@ -59,31 +59,28 @@ def make_h2(open_connection: OpenConnection) -> None:
@pytest.mark.parametrize("stream", [True, False])
def test_http2_client_aborts(tctx, frame_factory, stream):
def test_http2_client_aborts(tctx, stream):
"""Test handling of the case where a client aborts during request transmission."""
server = Placeholder(Server)
flow = Placeholder(HTTPFlow)
playbook = start_h2(tctx, frame_factory)
playbook, cff = start_h2_client(tctx)
def enable_streaming(flow: HTTPFlow):
flow.request.stream = True
assert (
playbook
>> DataReceived(tctx.client, frame_factory.build_headers_frame(example_request_headers).serialize())
>> DataReceived(tctx.client, cff.build_headers_frame(example_request_headers).serialize())
<< http.HttpRequestHeadersHook(flow)
)
if stream:
pytest.xfail("h2 client not implemented yet")
assert (
playbook
>> reply(side_effect=enable_streaming)
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 6\r\n\r\n"
b"abc")
>> reply(None)
<< SendData(server, b"GET / HTTP/1.1\r\n"
b"Host: example.com\r\n\r\n")
)
else:
assert playbook >> reply()
@ -100,6 +97,63 @@ def test_http2_client_aborts(tctx, frame_factory, stream):
@pytest.mark.xfail
def test_no_normalization():
def test_no_normalization(tctx):
"""Test that we don't normalize headers when we just pass them through."""
raise NotImplementedError
server = Placeholder(Server)
flow = Placeholder(HTTPFlow)
playbook, cff = start_h2_client(tctx)
request_headers = example_request_headers + (
(b"Should-Not-Be-Capitalized! ", b" :) "),
)
response_headers = example_response_headers + (
(b"Same", b"Here"),
)
initial = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(request_headers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, initial)
)
frames = decode_frames(initial())
assert [type(x) for x in frames] == [
hyperframe.frame.SettingsFrame,
hyperframe.frame.HeadersFrame,
hyperframe.frame.DataFrame
]
assert hpack.hpack.Decoder().decode(frames[1].data, True) == list(request_headers)
sff = FrameFactory()
assert (
playbook
<< SendData(server, sff.build_headers_frame(request_headers, flags=["END_STREAM"]).serialize())
>> DataReceived(server, sff.build_headers_frame(response_headers, flags=["END_STREAM"]).serialize())
<< http.HttpResponseHeadersHook(flow)
>> reply()
<< http.HttpResponseHook(flow)
>> reply()
<< SendData(tctx.client, cff.build_headers_frame(response_headers, flags=["END_STREAM"]).serialize())
)
assert flow().request.headers.fields == request_headers
assert flow().response.headers.fields == response_headers
def start_h2_server(playbook: Playbook) -> FrameFactory:
frame_factory = FrameFactory()
server = Placeholder(Server)
assert (
playbook
>> reply(None, side_effect=make_h2)
<< SendData(server, Placeholder())
)
playbook >> DataReceived(server, frame_factory.build_settings_frame({}, ack=True))
return frame_factory