support http/2 response trailers

This commit is contained in:
sanlengjingvv 2021-03-10 09:51:11 +08:00
parent e40bf0251d
commit 76d7ee3a2f
5 changed files with 82 additions and 5 deletions

View File

@ -17,10 +17,10 @@ from mitmproxy.proxy.utils import expect
from mitmproxy.utils import human
from mitmproxy.websocket import WebSocketData
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, ResponseTrailers, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
HttpResponseHook
HttpResponseHook, HttpResponseTrailersHook
from ._http1 import Http1Client, Http1Server
from ._http2 import Http2Client, Http2Server
from ...context import Context
@ -297,10 +297,13 @@ class HttpStream(layer.Layer):
elif isinstance(event, ResponseEndOfMessage):
yield from self.send_response(already_streamed=True)
@expect(ResponseData, ResponseEndOfMessage)
@expect(ResponseData, ResponseTrailers, ResponseEndOfMessage)
def state_consume_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, ResponseData):
self.response_body_buf += event.data
elif isinstance(event, ResponseTrailers):
self.flow.response.trailers = event.trailers
yield HttpResponseTrailersHook(self.flow)
elif isinstance(event, ResponseEndOfMessage):
assert self.flow.response
self.flow.response.data.content = self.response_body_buf
@ -336,6 +339,8 @@ class HttpStream(layer.Layer):
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response, not content), self.context.client)
if content:
yield SendHttp(ResponseData(self.stream_id, content), self.context.client)
if self.flow.response.trailers:
yield SendHttp(ResponseTrailers(self.stream_id, self.flow.response.trailers, end_stream=True), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)

View File

@ -45,6 +45,12 @@ class ResponseData(HttpEvent):
self.data = data
@dataclass
class ResponseTrailers(HttpEvent):
trailers: http.Headers
end_stream: bool = True
@dataclass
class RequestEndOfMessage(HttpEvent):
def __init__(self, stream_id: int):
@ -86,6 +92,7 @@ __all__ = [
"RequestEndOfMessage",
"ResponseHeaders",
"ResponseData",
"ResponseTrailers",
"ResponseEndOfMessage",
"RequestProtocolError",
"ResponseProtocolError",

View File

@ -44,6 +44,13 @@ class HttpResponseHook(commands.StartHook):
name = "response"
flow: http.HTTPFlow
@dataclass
class HttpResponseTrailersHook(commands.StartHook):
"""
The HTTP response trailers has been read.
"""
name = "responsetrailers"
flow: http.HTTPFlow
@dataclass
class HttpErrorHook(commands.StartHook):

View File

@ -17,7 +17,7 @@ from mitmproxy.connection import Connection
from mitmproxy.net.http import status_codes, url
from mitmproxy.utils import human
from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
ResponseEndOfMessage, ResponseHeaders, ResponseTrailers, ResponseProtocolError
from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error
from ._http_h2 import BufferedH2Connection, H2ConnectionLogger
from ...commands import CloseConnection, Log, SendData
@ -97,6 +97,13 @@ class Http2Connection(HttpConnection):
assert isinstance(event, (RequestData, ResponseData))
if self.is_open_for_us(event.stream_id):
self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, ResponseTrailers):
if self.is_open_for_us(event.stream_id):
trailers = [
*event.trailers.fields
]
r = event.trailers.fields
self.h2_conn.send_headers(event.stream_id, trailers, event.end_stream)
elif isinstance(event, self.SendEndOfMessage):
if self.is_open_for_us(event.stream_id):
self.h2_conn.end_stream(event.stream_id)
@ -216,7 +223,7 @@ class Http2Connection(HttpConnection):
elif isinstance(event, h2.events.PingAckReceived):
pass
elif isinstance(event, h2.events.TrailersReceived):
yield Log("Received HTTP/2 trailers, which are currently unimplemented and silently discarded", "error")
yield Log("Received HTTP/2 request trailers, which are currently unimplemented and silently discarded", "error")
elif isinstance(event, h2.events.PushedStreamReceived):
yield Log("Received HTTP/2 push promise, even though we signalled no support.", "error")
elif isinstance(event, h2.events.UnknownFrameReceived):
@ -448,6 +455,9 @@ class Http2Client(Http2Connection):
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
yield ReceiveHttp(ResponseHeaders(event.stream_id, response, bool(event.stream_ended)))
return False
elif isinstance(event, h2.events.TrailersReceived):
pseudo_trailers, trailers = split_pseudo_headers(event.headers)
yield ReceiveHttp(ResponseTrailers(event.stream_id, trailers, bool(event.stream_ended)))
elif isinstance(event, h2.events.RequestReceived):
yield from self.protocol_error(f"HTTP/2 protocol error: received request from server")
return True

View File

@ -30,6 +30,10 @@ example_response_headers = (
(b':status', b'200'),
)
example_response_trailers = (
(b'my-trailer-b', b'0'),
(b'my-trailer-b', b'0')
)
def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
# swallow preamble
@ -103,6 +107,50 @@ def test_simple(tctx):
assert flow().response.text == "Hello, World!"
def test_trailers(tctx):
playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow)
server = Placeholder(Server)
initial = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, initial)
)
frames = decode_frames(initial())
assert [type(x) for x in frames] == [
hyperframe.frame.SettingsFrame,
hyperframe.frame.HeadersFrame,
]
sff = FrameFactory()
assert (
playbook
# a conforming h2 server would send settings first, we disregard this for now.
>> DataReceived(server, sff.build_headers_frame(example_response_headers).serialize())
<< http.HttpResponseHeadersHook(flow)
>> reply()
>> DataReceived(server, sff.build_data_frame(b"Hello, World!").serialize())
>> DataReceived(server, sff.build_headers_frame(example_response_trailers, flags=["END_STREAM"]).serialize())
<< http.HttpResponseTrailersHook(flow)
>> reply()
<< http.HttpResponseHook(flow)
>> reply()
<< SendData(tctx.client,
cff.build_headers_frame(example_response_headers).serialize() +
cff.build_data_frame(b"Hello, World!").serialize() +
cff.build_headers_frame(example_response_trailers, flags=["END_STREAM"]).serialize())
)
assert flow().request.url == "http://example.com/"
assert flow().response.text == "Hello, World!"
def test_upstream_error(tctx):
playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow)