use request/response hooks for trailers

This commit is contained in:
Maximilian Hils 2021-03-15 17:12:10 +01:00
parent ed7067d36d
commit 12e4785d44
7 changed files with 129 additions and 121 deletions

View File

@ -17,10 +17,10 @@ from mitmproxy.proxy.utils import expect
from mitmproxy.utils import human
from mitmproxy.websocket import WebSocketData
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, RequestTrailers, \
ResponseTrailers, ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, RequestTrailers, \
ResponseData, ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError, ResponseTrailers
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
HttpResponseHook, HttpRequestTrailersHook, HttpResponseTrailersHook
HttpResponseHook
from ._http1 import Http1Client, Http1Server
from ._http2 import Http2Client, Http2Server
from ...context import Context
@ -211,7 +211,6 @@ class HttpStream(layer.Layer):
def start_request_stream(self) -> layer.CommandGenerator[None]:
if self.flow.response:
raise NotImplementedError("Can't set a response and enable streaming at the same time.")
yield HttpRequestHook(self.flow)
ok = yield from self.make_server_connection()
if not ok:
return
@ -227,11 +226,12 @@ class HttpStream(layer.Layer):
if callable(self.flow.request.stream):
event.data = self.flow.request.stream(event.data)
elif isinstance(event, RequestTrailers):
assert self.flow.request
self.flow.request.trailers = event.trailers
yield HttpRequestTrailersHook(self.flow)
# we don't do anything further here; we wait for RequestEndOfMessage to trigger the request hook first.
return
elif isinstance(event, RequestEndOfMessage):
self.flow.request.timestamp_end = time.time()
yield HttpRequestHook(self.flow)
self.client_state = self.state_done
# edge case found while fuzzing:
@ -245,6 +245,8 @@ class HttpStream(layer.Layer):
if isinstance(evt, ResponseProtocolError):
return
if self.flow.request.trailers:
# we've delayed sending trailers until after `request` has been triggered.
assert isinstance(event, RequestEndOfMessage)
yield SendHttp(RequestTrailers(self.stream_id, self.flow.request.trailers), self.context.server)
yield SendHttp(event, self.context.server)
@ -255,7 +257,6 @@ class HttpStream(layer.Layer):
elif isinstance(event, RequestTrailers):
assert self.flow.request
self.flow.request.trailers = event.trailers
yield HttpRequestTrailersHook(self.flow)
elif isinstance(event, RequestEndOfMessage):
self.flow.request.timestamp_end = time.time()
self.flow.request.data.content = self.request_body_buf
@ -309,7 +310,7 @@ class HttpStream(layer.Layer):
elif isinstance(event, ResponseTrailers):
assert self.flow.response
self.flow.response.trailers = event.trailers
yield HttpResponseTrailersHook(self.flow)
# will be sent in send_response() after the response hook.
elif isinstance(event, ResponseEndOfMessage):
yield from self.send_response(already_streamed=True)
@ -320,7 +321,6 @@ class HttpStream(layer.Layer):
elif isinstance(event, ResponseTrailers):
assert self.flow.response
self.flow.response.trailers = event.trailers
yield HttpResponseTrailersHook(self.flow)
elif isinstance(event, ResponseEndOfMessage):
assert self.flow.response
self.flow.response.data.content = self.response_body_buf
@ -356,9 +356,9 @@ class HttpStream(layer.Layer):
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not content), self.context.client)
if content:
yield SendHttp(ResponseData(self.stream_id, content), self.context.client)
if self.flow.response.trailers:
yield SendHttp(ResponseTrailers(self.stream_id, self.flow.response.trailers), self.context.client)
if self.flow.response.trailers:
yield SendHttp(ResponseTrailers(self.stream_id, self.flow.response.trailers), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
if self.flow.response.status_code == 101:

View File

@ -49,11 +49,19 @@ class ResponseData(HttpEvent):
class RequestTrailers(HttpEvent):
trailers: http.Headers
def __init__(self, stream_id: int, trailers: http.Headers):
self.stream_id = stream_id
self.trailers = trailers
@dataclass
class ResponseTrailers(HttpEvent):
trailers: http.Headers
def __init__(self, stream_id: int, trailers: http.Headers):
self.stream_id = stream_id
self.trailers = trailers
@dataclass
class RequestEndOfMessage(HttpEvent):

View File

@ -18,8 +18,10 @@ class HttpRequestHook(commands.StartHook):
"""
The full HTTP request has been read.
Note: This event fires immediately after requestheaders if the request body is streamed.
This ensures that requestheaders -> request -> responseheaders -> response happen in that order.
Note: If request streaming is active, this event fires after the entire body has been streamed.
HTTP trailers, if present, have not been transmitted to the server yet and can still be modified.
Enabling streaming may cause unexpected event sequences: For example, `response` may now occur
before `request` because the server replied with "413 Payload Too Large" during upload.
"""
name = "request"
flow: http.HTTPFlow
@ -28,7 +30,7 @@ class HttpRequestHook(commands.StartHook):
@dataclass
class HttpResponseHeadersHook(commands.StartHook):
"""
The full HTTP response has been read.
HTTP response headers were successfully read. At this point, the body is empty.
"""
name = "responseheaders"
flow: http.HTTPFlow
@ -37,42 +39,15 @@ class HttpResponseHeadersHook(commands.StartHook):
@dataclass
class HttpResponseHook(commands.StartHook):
"""
HTTP response headers were successfully read. At this point, the body is empty.
The full HTTP response has been read.
Note: If response streaming is active, this event fires after the entire body has been streamed.
HTTP trailers, if present, have not been transmitted to the client yet and can still be modified.
"""
name = "response"
flow: http.HTTPFlow
@dataclass
class HttpRequestTrailersHook(commands.StartHook):
"""
The HTTP request trailers has been read.
HTTP trailers are a rarely-used feature in the HTTP specification
which allows peers to send additional headers after the message body.
This is useful for metadata that is dynamically generated while
the message body is sent, for example a digital signature
or post-processing status.
"""
name = "requesttrailers"
flow: http.HTTPFlow
@dataclass
class HttpResponseTrailersHook(commands.StartHook):
"""
The HTTP response trailers has been read.
HTTP trailers are a rarely-used feature in the HTTP specification
which allows peers to send additional headers after the message body.
This is useful for metadata that is dynamically generated while
the message body is sent, for example a digital signature
or post-processing status.
"""
name = "responsetrailers"
flow: http.HTTPFlow
@dataclass
class HttpErrorHook(commands.StartHook):
"""

View File

@ -51,6 +51,7 @@ class Http2Connection(HttpConnection):
ReceiveProtocolError: Type[Union[RequestProtocolError, ResponseProtocolError]]
ReceiveData: Type[Union[RequestData, ResponseData]]
ReceiveTrailers: Type[Union[RequestTrailers, ResponseTrailers]]
ReceiveEndOfMessage: Type[Union[RequestEndOfMessage, ResponseEndOfMessage]]
def __init__(self, context: Context, conn: Connection):
@ -175,6 +176,9 @@ class Http2Connection(HttpConnection):
yield from self.protocol_error(f"Received HTTP/2 data frame, expected headers.")
return True
self.h2_conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id)
elif isinstance(event, h2.events.TrailersReceived):
trailers = http.Headers(event.headers)
yield ReceiveHttp(self.ReceiveTrailers(event.stream_id, trailers))
elif isinstance(event, h2.events.StreamEnded):
state = self.streams.get(event.stream_id, None)
if state is StreamState.HEADERS_RECEIVED:
@ -219,8 +223,6 @@ class Http2Connection(HttpConnection):
pass
elif isinstance(event, h2.events.PingAckReceived):
pass
elif isinstance(event, h2.events.TrailersReceived):
pass
elif isinstance(event, h2.events.PushedStreamReceived):
yield Log("Received HTTP/2 push promise, even though we signalled no support.", "error")
elif isinstance(event, h2.events.UnknownFrameReceived):
@ -278,6 +280,7 @@ class Http2Server(Http2Connection):
ReceiveProtocolError = RequestProtocolError
ReceiveData = RequestData
ReceiveTrailers = RequestTrailers
ReceiveEndOfMessage = RequestEndOfMessage
def __init__(self, context: Context):
@ -326,10 +329,6 @@ class Http2Server(Http2Connection):
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
yield ReceiveHttp(RequestHeaders(event.stream_id, request, end_stream=bool(event.stream_ended)))
return False
elif isinstance(event, h2.events.TrailersReceived):
trailers = http.Headers(event.headers)
yield ReceiveHttp(RequestTrailers(event.stream_id, trailers))
return False
else:
return (yield from super().handle_h2_event(event))
@ -346,6 +345,7 @@ class Http2Client(Http2Connection):
ReceiveProtocolError = ResponseProtocolError
ReceiveData = ResponseData
ReceiveTrailers = ResponseTrailers
ReceiveEndOfMessage = ResponseEndOfMessage
our_stream_id: Dict[int, int]
@ -456,10 +456,6 @@ class Http2Client(Http2Connection):
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
yield ReceiveHttp(ResponseHeaders(event.stream_id, response, bool(event.stream_ended)))
return False
elif isinstance(event, h2.events.TrailersReceived):
trailers = http.Headers(event.headers)
yield ReceiveHttp(ResponseTrailers(event.stream_id, trailers))
return False
elif isinstance(event, h2.events.RequestReceived):
yield from self.protocol_error(f"HTTP/2 protocol error: received request from server")
return True

View File

@ -316,8 +316,6 @@ def test_request_streaming(tctx, response):
b"abc")
<< http.HttpRequestHeadersHook(flow)
>> reply(side_effect=enable_streaming)
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"POST / HTTP/1.1\r\n"
@ -330,6 +328,8 @@ def test_request_streaming(tctx, response):
playbook
>> DataReceived(tctx.client, b"def")
<< SendData(server, b"DEF")
<< http.HttpRequestHook(flow)
>> reply()
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
<< http.HttpResponseHeadersHook(flow)
>> reply()
@ -350,7 +350,9 @@ def test_request_streaming(tctx, response):
>> reply()
<< SendData(tctx.client, b"HTTP/1.1 413 Request Entity Too Large\r\nContent-Length: 0\r\n\r\n")
>> DataReceived(tctx.client, b"def")
<< SendData(server, b"DEF") # Important: no request hook here!
<< SendData(server, b"DEF")
<< http.HttpRequestHook(flow)
>> reply()
)
elif response == "early close":
assert (
@ -705,8 +707,6 @@ def test_http_client_aborts(tctx, stream):
assert (
playbook
>> reply(side_effect=enable_streaming)
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"POST / HTTP/1.1\r\n"

View File

@ -6,16 +6,16 @@ import hyperframe.frame
import pytest
from h2.errors import ErrorCodes
from mitmproxy.connection import ConnectionState, Server
from mitmproxy.flow import Error
from mitmproxy.http import HTTPFlow, Headers, Request
from mitmproxy.net.http import status_codes
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.layers.http import HTTPMode
from mitmproxy.proxy.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.connection import Server
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.events import ConnectionClosed, DataReceived
from mitmproxy.proxy.layers import http
from mitmproxy.proxy.layers.http._http2 import split_pseudo_headers, Http2Client
from mitmproxy.proxy.layers.http import HTTPMode
from mitmproxy.proxy.layers.http._http2 import Http2Client, split_pseudo_headers
from test.mitmproxy.proxy.layers.http.hyper_h2_test_helpers import FrameFactory
from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply
@ -41,6 +41,16 @@ example_response_trailers = (
)
@pytest.fixture
def open_h2_server_conn():
# this is a bit fake here (port 80, with alpn, but no tls - c'mon),
# but we don't want to pollute our tests with TLS handshakes.
s = Server(("example.com", 80))
s.state = ConnectionState.OPEN
s.alpn = b"h2"
return s
def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
# swallow preamble
if data.startswith(b"PRI * HTTP/2.0"):
@ -113,73 +123,92 @@ def test_simple(tctx):
assert flow().response.text == "Hello, World!"
def test_response_trailers(tctx):
@pytest.mark.parametrize("stream", ["stream", ""])
def test_response_trailers(tctx: Context, open_h2_server_conn: Server, stream):
playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow)
server = Placeholder(Server)
initial = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, initial)
)
frames = decode_frames(initial())
assert [type(x) for x in frames] == [
hyperframe.frame.SettingsFrame,
hyperframe.frame.HeadersFrame,
]
tctx.server = open_h2_server_conn
sff = FrameFactory()
def enable_streaming(flow: HTTPFlow):
flow.response.stream = bool(stream)
flow = Placeholder(HTTPFlow)
(
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< SendData(tctx.server, Placeholder(bytes))
# a conforming h2 server would send settings first; we disregard this for now.
>> DataReceived(tctx.server, sff.build_headers_frame(example_response_headers).serialize() +
sff.build_data_frame(b"Hello, World!").serialize())
<< http.HttpResponseHeadersHook(flow)
>> reply(side_effect=enable_streaming)
)
if stream:
playbook << SendData(
tctx.client,
cff.build_headers_frame(example_response_headers).serialize() +
cff.build_data_frame(b"Hello, World!").serialize()
)
assert (
playbook
>> DataReceived(tctx.server, sff.build_headers_frame(example_response_trailers, flags=["END_STREAM"]).serialize())
<< http.HttpResponseHook(flow)
)
assert flow().response.trailers
del flow().response.trailers["resp-trailer-a"]
if stream:
assert (
playbook
# a conforming h2 server would send settings first; we disregard this for now.
>> DataReceived(server, sff.build_headers_frame(example_response_headers).serialize())
<< http.HttpResponseHeadersHook(flow)
>> reply()
>> DataReceived(server, sff.build_data_frame(b"Hello, World!").serialize())
>> DataReceived(server, sff.build_headers_frame(example_response_trailers, flags=["END_STREAM"]).serialize())
<< http.HttpResponseTrailersHook(flow)
>> reply()
<< http.HttpResponseHook(flow)
<< SendData(tctx.client,
cff.build_headers_frame(example_response_trailers[1:], flags=["END_STREAM"]).serialize())
)
else:
assert (
playbook
>> reply()
<< SendData(tctx.client,
cff.build_headers_frame(example_response_headers).serialize() +
cff.build_data_frame(b"Hello, World!").serialize() +
cff.build_headers_frame(example_response_trailers, flags=["END_STREAM"]).serialize())
)
assert flow().request.url == "http://example.com/"
assert flow().response.text == "Hello, World!"
cff.build_headers_frame(example_response_trailers[1:], flags=["END_STREAM"]).serialize()))
def test_request_trailers(tctx):
@pytest.mark.parametrize("stream", ["stream", ""])
def test_request_trailers(tctx: Context, open_h2_server_conn: Server, stream):
playbook, cff = start_h2_client(tctx)
tctx.server = open_h2_server_conn
def enable_streaming(flow: HTTPFlow):
flow.request.stream = bool(stream)
flow = Placeholder(HTTPFlow)
server = Placeholder(Server)
initial = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
>> DataReceived(tctx.client, cff.build_data_frame(b"Hello, World!").serialize())
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_trailers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestTrailersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, initial)
server_data1 = Placeholder(bytes)
server_data2 = Placeholder(bytes)
(
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers).serialize() +
cff.build_data_frame(b"Hello, World!").serialize()
)
<< http.HttpRequestHeadersHook(flow)
>> reply(side_effect=enable_streaming)
)
frames = decode_frames(initial())
if stream:
playbook << SendData(tctx.server, server_data1)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_trailers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestHook(flow)
>> reply()
<< SendData(tctx.server, server_data2)
)
frames = decode_frames(server_data1.setdefault(b"") + server_data2())
assert [type(x) for x in frames] == [
hyperframe.frame.SettingsFrame,
hyperframe.frame.HeadersFrame,
@ -248,8 +277,6 @@ def test_http2_client_aborts(tctx, stream, when, how):
assert (
playbook
>> reply(side_effect=enable_request_streaming)
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"GET / HTTP/1.1\r\n"
@ -589,9 +616,11 @@ def test_stream_concurrent_get_connection(tctx):
data = Placeholder(bytes)
assert (playbook
>> DataReceived(tctx.client, cff.build_headers_frame(example_request_headers, flags=["END_STREAM"], stream_id=1).serialize())
>> DataReceived(tctx.client, cff.build_headers_frame(example_request_headers, flags=["END_STREAM"],
stream_id=1).serialize())
<< (o := OpenConnection(server))
>> DataReceived(tctx.client, cff.build_headers_frame(example_request_headers, flags=["END_STREAM"], stream_id=3).serialize())
>> DataReceived(tctx.client, cff.build_headers_frame(example_request_headers, flags=["END_STREAM"],
stream_id=3).serialize())
>> reply(None, to=o, side_effect=make_h2)
<< SendData(server, data)
)

View File

@ -279,9 +279,9 @@ def test_fuzz_h2_response_mutations(chunks):
@pytest.mark.parametrize("example", [(
True, False,
["data_req", "reply_hook_req_headers", "reply_hook_req", "reply_openconn", "data_resp", "data_reqbody",
["data_req", "reply_hook_req_headers", "reply_openconn", "data_resp", "data_reqbody",
"data_respbody", "err_server_rst", "reply_hook_resp_headers"]),
(True, False, ["data_req", "reply_hook_req_headers", "reply_hook_req", "reply_openconn", "err_server_rst",
(True, False, ["data_req", "reply_hook_req_headers", "reply_openconn", "err_server_rst",
"data_reqbody", "reply_hook_error"]),
])
def test_cancel_examples(example):