clean code,

support request trailers when response body is streamed,
support  trailers when  body is streamed
This commit is contained in:
sanlengjingvv 2021-03-15 11:28:10 +08:00
parent a3fd70f240
commit ed7067d36d
5 changed files with 98 additions and 18 deletions

View File

@ -17,10 +17,10 @@ from mitmproxy.proxy.utils import expect
from mitmproxy.utils import human from mitmproxy.utils import human
from mitmproxy.websocket import WebSocketData from mitmproxy.websocket import WebSocketData
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, ResponseTrailers, \ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, RequestTrailers, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError ResponseTrailers, ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \ from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
HttpResponseHook, HttpResponseTrailersHook HttpResponseHook, HttpRequestTrailersHook, HttpResponseTrailersHook
from ._http1 import Http1Client, Http1Server from ._http1 import Http1Client, Http1Server
from ._http2 import Http2Client, Http2Server from ._http2 import Http2Client, Http2Server
from ...context import Context from ...context import Context
@ -129,7 +129,7 @@ class HttpStream(layer.Layer):
self.client_state = self.state_wait_for_request_headers self.client_state = self.state_wait_for_request_headers
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)): elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
yield from self.handle_protocol_error(event) yield from self.handle_protocol_error(event)
elif isinstance(event, (RequestHeaders, RequestData, RequestEndOfMessage)): elif isinstance(event, (RequestHeaders, RequestData, RequestTrailers, RequestEndOfMessage)):
yield from self.client_state(event) yield from self.client_state(event)
else: else:
yield from self.server_state(event) yield from self.server_state(event)
@ -221,11 +221,15 @@ class HttpStream(layer.Layer):
) )
self.client_state = self.state_stream_request_body self.client_state = self.state_stream_request_body
@expect(RequestData, RequestEndOfMessage) @expect(RequestData, RequestTrailers, RequestEndOfMessage)
def state_stream_request_body(self, event: Union[RequestData, RequestEndOfMessage]) -> layer.CommandGenerator[None]: def state_stream_request_body(self, event: Union[RequestData, RequestEndOfMessage]) -> layer.CommandGenerator[None]:
if isinstance(event, RequestData): if isinstance(event, RequestData):
if callable(self.flow.request.stream): if callable(self.flow.request.stream):
event.data = self.flow.request.stream(event.data) event.data = self.flow.request.stream(event.data)
elif isinstance(event, RequestTrailers):
assert self.flow.request
self.flow.request.trailers = event.trailers
yield HttpRequestTrailersHook(self.flow)
elif isinstance(event, RequestEndOfMessage): elif isinstance(event, RequestEndOfMessage):
self.flow.request.timestamp_end = time.time() self.flow.request.timestamp_end = time.time()
self.client_state = self.state_done self.client_state = self.state_done
@ -240,12 +244,18 @@ class HttpStream(layer.Layer):
for evt in self._paused_event_queue: for evt in self._paused_event_queue:
if isinstance(evt, ResponseProtocolError): if isinstance(evt, ResponseProtocolError):
return return
if self.flow.request.trailers:
yield SendHttp(RequestTrailers(self.stream_id, self.flow.request.trailers), self.context.server)
yield SendHttp(event, self.context.server) yield SendHttp(event, self.context.server)
@expect(RequestData, RequestEndOfMessage) @expect(RequestData, RequestTrailers, RequestEndOfMessage)
def state_consume_request_body(self, event: events.Event) -> layer.CommandGenerator[None]: def state_consume_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, RequestData): if isinstance(event, RequestData):
self.request_body_buf += event.data self.request_body_buf += event.data
elif isinstance(event, RequestTrailers):
assert self.flow.request
self.flow.request.trailers = event.trailers
yield HttpRequestTrailersHook(self.flow)
elif isinstance(event, RequestEndOfMessage): elif isinstance(event, RequestEndOfMessage):
self.flow.request.timestamp_end = time.time() self.flow.request.timestamp_end = time.time()
self.flow.request.data.content = self.request_body_buf self.flow.request.data.content = self.request_body_buf
@ -271,6 +281,8 @@ class HttpStream(layer.Layer):
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request, not content), self.context.server) yield SendHttp(RequestHeaders(self.stream_id, self.flow.request, not content), self.context.server)
if content: if content:
yield SendHttp(RequestData(self.stream_id, content), self.context.server) yield SendHttp(RequestData(self.stream_id, content), self.context.server)
if self.flow.request.trailers:
yield SendHttp(RequestTrailers(self.stream_id, self.flow.request.trailers), self.context.server)
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server) yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
@expect(ResponseHeaders) @expect(ResponseHeaders)
@ -285,7 +297,7 @@ class HttpStream(layer.Layer):
else: else:
self.server_state = self.state_consume_response_body self.server_state = self.state_consume_response_body
@expect(ResponseData, ResponseEndOfMessage) @expect(ResponseData, ResponseTrailers, ResponseEndOfMessage)
def state_stream_response_body(self, event: events.Event) -> layer.CommandGenerator[None]: def state_stream_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.flow.response assert self.flow.response
if isinstance(event, ResponseData): if isinstance(event, ResponseData):
@ -294,6 +306,10 @@ class HttpStream(layer.Layer):
else: else:
data = event.data data = event.data
yield SendHttp(ResponseData(self.stream_id, data), self.context.client) yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
elif isinstance(event, ResponseTrailers):
assert self.flow.response
self.flow.response.trailers = event.trailers
yield HttpResponseTrailersHook(self.flow)
elif isinstance(event, ResponseEndOfMessage): elif isinstance(event, ResponseEndOfMessage):
yield from self.send_response(already_streamed=True) yield from self.send_response(already_streamed=True)
@ -341,7 +357,7 @@ class HttpStream(layer.Layer):
if content: if content:
yield SendHttp(ResponseData(self.stream_id, content), self.context.client) yield SendHttp(ResponseData(self.stream_id, content), self.context.client)
if self.flow.response.trailers: if self.flow.response.trailers:
yield SendHttp(ResponseTrailers(self.stream_id, self.flow.response.trailers, end_stream=True), self.context.client) yield SendHttp(ResponseTrailers(self.stream_id, self.flow.response.trailers), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client) yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)

View File

@ -45,10 +45,14 @@ class ResponseData(HttpEvent):
self.data = data self.data = data
@dataclass
class RequestTrailers(HttpEvent):
trailers: http.Headers
@dataclass @dataclass
class ResponseTrailers(HttpEvent): class ResponseTrailers(HttpEvent):
trailers: http.Headers trailers: http.Headers
end_stream: bool = True
@dataclass @dataclass
@ -92,6 +96,7 @@ __all__ = [
"RequestEndOfMessage", "RequestEndOfMessage",
"ResponseHeaders", "ResponseHeaders",
"ResponseData", "ResponseData",
"RequestTrailers",
"ResponseTrailers", "ResponseTrailers",
"ResponseEndOfMessage", "ResponseEndOfMessage",
"RequestProtocolError", "RequestProtocolError",

View File

@ -45,10 +45,29 @@ class HttpResponseHook(commands.StartHook):
flow: http.HTTPFlow flow: http.HTTPFlow
@dataclass
class HttpRequestTrailersHook(commands.StartHook):
"""
The HTTP request trailers has been read.
HTTP trailers are a rarely-used feature in the HTTP specification
which allows peers to send additional headers after the message body.
This is useful for metadata that is dynamically generated while
the message body is sent, for example a digital signature
or post-processing status.
"""
name = "requesttrailers"
flow: http.HTTPFlow
@dataclass @dataclass
class HttpResponseTrailersHook(commands.StartHook): class HttpResponseTrailersHook(commands.StartHook):
""" """
The HTTP response trailers has been read. The HTTP response trailers has been read.
HTTP trailers are a rarely-used feature in the HTTP specification
which allows peers to send additional headers after the message body.
This is useful for metadata that is dynamically generated while
the message body is sent, for example a digital signature
or post-processing status.
""" """
name = "responsetrailers" name = "responsetrailers"
flow: http.HTTPFlow flow: http.HTTPFlow

View File

@ -17,7 +17,7 @@ from mitmproxy.connection import Connection
from mitmproxy.net.http import status_codes, url from mitmproxy.net.http import status_codes, url
from mitmproxy.utils import human from mitmproxy.utils import human
from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseTrailers, ResponseProtocolError ResponseEndOfMessage, ResponseHeaders, RequestTrailers, ResponseTrailers, ResponseProtocolError
from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error
from ._http_h2 import BufferedH2Connection, H2ConnectionLogger from ._http_h2 import BufferedH2Connection, H2ConnectionLogger
from ...commands import CloseConnection, Log, SendData from ...commands import CloseConnection, Log, SendData
@ -97,10 +97,10 @@ class Http2Connection(HttpConnection):
assert isinstance(event, (RequestData, ResponseData)) assert isinstance(event, (RequestData, ResponseData))
if self.is_open_for_us(event.stream_id): if self.is_open_for_us(event.stream_id):
self.h2_conn.send_data(event.stream_id, event.data) self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, ResponseTrailers): elif isinstance(event, (RequestTrailers, ResponseTrailers)):
if self.is_open_for_us(event.stream_id): if self.is_open_for_us(event.stream_id):
trailers = [*event.trailers.fields] trailers = [*event.trailers.fields]
self.h2_conn.send_headers(event.stream_id, trailers, event.end_stream) self.h2_conn.send_headers(event.stream_id, trailers, end_stream=True)
elif isinstance(event, self.SendEndOfMessage): elif isinstance(event, self.SendEndOfMessage):
if self.is_open_for_us(event.stream_id): if self.is_open_for_us(event.stream_id):
self.h2_conn.end_stream(event.stream_id) self.h2_conn.end_stream(event.stream_id)
@ -220,7 +220,7 @@ class Http2Connection(HttpConnection):
elif isinstance(event, h2.events.PingAckReceived): elif isinstance(event, h2.events.PingAckReceived):
pass pass
elif isinstance(event, h2.events.TrailersReceived): elif isinstance(event, h2.events.TrailersReceived):
yield Log("Received HTTP/2 request trailers, which are currently unimplemented and silently discarded", "error") pass
elif isinstance(event, h2.events.PushedStreamReceived): elif isinstance(event, h2.events.PushedStreamReceived):
yield Log("Received HTTP/2 push promise, even though we signalled no support.", "error") yield Log("Received HTTP/2 push promise, even though we signalled no support.", "error")
elif isinstance(event, h2.events.UnknownFrameReceived): elif isinstance(event, h2.events.UnknownFrameReceived):
@ -326,6 +326,10 @@ class Http2Server(Http2Connection):
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
yield ReceiveHttp(RequestHeaders(event.stream_id, request, end_stream=bool(event.stream_ended))) yield ReceiveHttp(RequestHeaders(event.stream_id, request, end_stream=bool(event.stream_ended)))
return False return False
elif isinstance(event, h2.events.TrailersReceived):
trailers = http.Headers(event.headers)
yield ReceiveHttp(RequestTrailers(event.stream_id, trailers))
return False
else: else:
return (yield from super().handle_h2_event(event)) return (yield from super().handle_h2_event(event))
@ -453,8 +457,8 @@ class Http2Client(Http2Connection):
yield ReceiveHttp(ResponseHeaders(event.stream_id, response, bool(event.stream_ended))) yield ReceiveHttp(ResponseHeaders(event.stream_id, response, bool(event.stream_ended)))
return False return False
elif isinstance(event, h2.events.TrailersReceived): elif isinstance(event, h2.events.TrailersReceived):
pseudo_trailers, trailers = split_pseudo_headers(event.headers) trailers = http.Headers(event.headers)
yield ReceiveHttp(ResponseTrailers(event.stream_id, trailers, bool(event.stream_ended))) yield ReceiveHttp(ResponseTrailers(event.stream_id, trailers))
return False return False
elif isinstance(event, h2.events.RequestReceived): elif isinstance(event, h2.events.RequestReceived):
yield from self.protocol_error(f"HTTP/2 protocol error: received request from server") yield from self.protocol_error(f"HTTP/2 protocol error: received request from server")

View File

@ -30,9 +30,14 @@ example_response_headers = (
(b':status', b'200'), (b':status', b'200'),
) )
example_request_trailers = (
(b'req-trailer-a', b'a'),
(b'req-trailer-b', b'b')
)
example_response_trailers = ( example_response_trailers = (
(b'my-trailer-b', b'0'), (b'resp-trailer-a', b'a'),
(b'my-trailer-b', b'0') (b'resp-trailer-b', b'b')
) )
@ -108,7 +113,7 @@ def test_simple(tctx):
assert flow().response.text == "Hello, World!" assert flow().response.text == "Hello, World!"
def test_trailers(tctx): def test_response_trailers(tctx):
playbook, cff = start_h2_client(tctx) playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow) flow = Placeholder(HTTPFlow)
server = Placeholder(Server) server = Placeholder(Server)
@ -152,6 +157,37 @@ def test_trailers(tctx):
assert flow().response.text == "Hello, World!" assert flow().response.text == "Hello, World!"
def test_request_trailers(tctx):
playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow)
server = Placeholder(Server)
initial = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
>> DataReceived(tctx.client, cff.build_data_frame(b"Hello, World!").serialize())
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_trailers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestTrailersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, initial)
)
frames = decode_frames(initial())
assert [type(x) for x in frames] == [
hyperframe.frame.SettingsFrame,
hyperframe.frame.HeadersFrame,
hyperframe.frame.DataFrame,
hyperframe.frame.HeadersFrame,
]
def test_upstream_error(tctx): def test_upstream_error(tctx):
playbook, cff = start_h2_client(tctx) playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow) flow = Placeholder(HTTPFlow)