clean code,

support request trailers when response body is streamed,
support  trailers when  body is streamed
This commit is contained in:
sanlengjingvv 2021-03-15 11:28:10 +08:00
parent a3fd70f240
commit ed7067d36d
5 changed files with 98 additions and 18 deletions

View File

@ -17,10 +17,10 @@ from mitmproxy.proxy.utils import expect
from mitmproxy.utils import human
from mitmproxy.websocket import WebSocketData
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, ResponseTrailers, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, RequestTrailers, \
ResponseTrailers, ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
HttpResponseHook, HttpResponseTrailersHook
HttpResponseHook, HttpRequestTrailersHook, HttpResponseTrailersHook
from ._http1 import Http1Client, Http1Server
from ._http2 import Http2Client, Http2Server
from ...context import Context
@ -129,7 +129,7 @@ class HttpStream(layer.Layer):
self.client_state = self.state_wait_for_request_headers
elif isinstance(event, (RequestProtocolError, ResponseProtocolError)):
yield from self.handle_protocol_error(event)
elif isinstance(event, (RequestHeaders, RequestData, RequestEndOfMessage)):
elif isinstance(event, (RequestHeaders, RequestData, RequestTrailers, RequestEndOfMessage)):
yield from self.client_state(event)
else:
yield from self.server_state(event)
@ -221,11 +221,15 @@ class HttpStream(layer.Layer):
)
self.client_state = self.state_stream_request_body
@expect(RequestData, RequestEndOfMessage)
@expect(RequestData, RequestTrailers, RequestEndOfMessage)
def state_stream_request_body(self, event: Union[RequestData, RequestEndOfMessage]) -> layer.CommandGenerator[None]:
if isinstance(event, RequestData):
if callable(self.flow.request.stream):
event.data = self.flow.request.stream(event.data)
elif isinstance(event, RequestTrailers):
assert self.flow.request
self.flow.request.trailers = event.trailers
yield HttpRequestTrailersHook(self.flow)
elif isinstance(event, RequestEndOfMessage):
self.flow.request.timestamp_end = time.time()
self.client_state = self.state_done
@ -240,12 +244,18 @@ class HttpStream(layer.Layer):
for evt in self._paused_event_queue:
if isinstance(evt, ResponseProtocolError):
return
if self.flow.request.trailers:
yield SendHttp(RequestTrailers(self.stream_id, self.flow.request.trailers), self.context.server)
yield SendHttp(event, self.context.server)
@expect(RequestData, RequestEndOfMessage)
@expect(RequestData, RequestTrailers, RequestEndOfMessage)
def state_consume_request_body(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, RequestData):
self.request_body_buf += event.data
elif isinstance(event, RequestTrailers):
assert self.flow.request
self.flow.request.trailers = event.trailers
yield HttpRequestTrailersHook(self.flow)
elif isinstance(event, RequestEndOfMessage):
self.flow.request.timestamp_end = time.time()
self.flow.request.data.content = self.request_body_buf
@ -271,6 +281,8 @@ class HttpStream(layer.Layer):
yield SendHttp(RequestHeaders(self.stream_id, self.flow.request, not content), self.context.server)
if content:
yield SendHttp(RequestData(self.stream_id, content), self.context.server)
if self.flow.request.trailers:
yield SendHttp(RequestTrailers(self.stream_id, self.flow.request.trailers), self.context.server)
yield SendHttp(RequestEndOfMessage(self.stream_id), self.context.server)
@expect(ResponseHeaders)
@ -285,7 +297,7 @@ class HttpStream(layer.Layer):
else:
self.server_state = self.state_consume_response_body
@expect(ResponseData, ResponseEndOfMessage)
@expect(ResponseData, ResponseTrailers, ResponseEndOfMessage)
def state_stream_response_body(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.flow.response
if isinstance(event, ResponseData):
@ -294,6 +306,10 @@ class HttpStream(layer.Layer):
else:
data = event.data
yield SendHttp(ResponseData(self.stream_id, data), self.context.client)
elif isinstance(event, ResponseTrailers):
assert self.flow.response
self.flow.response.trailers = event.trailers
yield HttpResponseTrailersHook(self.flow)
elif isinstance(event, ResponseEndOfMessage):
yield from self.send_response(already_streamed=True)
@ -341,7 +357,7 @@ class HttpStream(layer.Layer):
if content:
yield SendHttp(ResponseData(self.stream_id, content), self.context.client)
if self.flow.response.trailers:
yield SendHttp(ResponseTrailers(self.stream_id, self.flow.response.trailers, end_stream=True), self.context.client)
yield SendHttp(ResponseTrailers(self.stream_id, self.flow.response.trailers), self.context.client)
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)

View File

@ -45,10 +45,14 @@ class ResponseData(HttpEvent):
self.data = data
@dataclass
class RequestTrailers(HttpEvent):
trailers: http.Headers
@dataclass
class ResponseTrailers(HttpEvent):
trailers: http.Headers
end_stream: bool = True
@dataclass
@ -92,6 +96,7 @@ __all__ = [
"RequestEndOfMessage",
"ResponseHeaders",
"ResponseData",
"RequestTrailers",
"ResponseTrailers",
"ResponseEndOfMessage",
"RequestProtocolError",

View File

@ -45,10 +45,29 @@ class HttpResponseHook(commands.StartHook):
flow: http.HTTPFlow
@dataclass
class HttpRequestTrailersHook(commands.StartHook):
"""
The HTTP request trailers has been read.
HTTP trailers are a rarely-used feature in the HTTP specification
which allows peers to send additional headers after the message body.
This is useful for metadata that is dynamically generated while
the message body is sent, for example a digital signature
or post-processing status.
"""
name = "requesttrailers"
flow: http.HTTPFlow
@dataclass
class HttpResponseTrailersHook(commands.StartHook):
"""
The HTTP response trailers has been read.
HTTP trailers are a rarely-used feature in the HTTP specification
which allows peers to send additional headers after the message body.
This is useful for metadata that is dynamically generated while
the message body is sent, for example a digital signature
or post-processing status.
"""
name = "responsetrailers"
flow: http.HTTPFlow

View File

@ -17,7 +17,7 @@ from mitmproxy.connection import Connection
from mitmproxy.net.http import status_codes, url
from mitmproxy.utils import human
from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseTrailers, ResponseProtocolError
ResponseEndOfMessage, ResponseHeaders, RequestTrailers, ResponseTrailers, ResponseProtocolError
from ._base import HttpConnection, HttpEvent, ReceiveHttp, format_error
from ._http_h2 import BufferedH2Connection, H2ConnectionLogger
from ...commands import CloseConnection, Log, SendData
@ -97,10 +97,10 @@ class Http2Connection(HttpConnection):
assert isinstance(event, (RequestData, ResponseData))
if self.is_open_for_us(event.stream_id):
self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, ResponseTrailers):
elif isinstance(event, (RequestTrailers, ResponseTrailers)):
if self.is_open_for_us(event.stream_id):
trailers = [*event.trailers.fields]
self.h2_conn.send_headers(event.stream_id, trailers, event.end_stream)
self.h2_conn.send_headers(event.stream_id, trailers, end_stream=True)
elif isinstance(event, self.SendEndOfMessage):
if self.is_open_for_us(event.stream_id):
self.h2_conn.end_stream(event.stream_id)
@ -220,7 +220,7 @@ class Http2Connection(HttpConnection):
elif isinstance(event, h2.events.PingAckReceived):
pass
elif isinstance(event, h2.events.TrailersReceived):
yield Log("Received HTTP/2 request trailers, which are currently unimplemented and silently discarded", "error")
pass
elif isinstance(event, h2.events.PushedStreamReceived):
yield Log("Received HTTP/2 push promise, even though we signalled no support.", "error")
elif isinstance(event, h2.events.UnknownFrameReceived):
@ -326,6 +326,10 @@ class Http2Server(Http2Connection):
self.streams[event.stream_id] = StreamState.HEADERS_RECEIVED
yield ReceiveHttp(RequestHeaders(event.stream_id, request, end_stream=bool(event.stream_ended)))
return False
elif isinstance(event, h2.events.TrailersReceived):
trailers = http.Headers(event.headers)
yield ReceiveHttp(RequestTrailers(event.stream_id, trailers))
return False
else:
return (yield from super().handle_h2_event(event))
@ -453,8 +457,8 @@ class Http2Client(Http2Connection):
yield ReceiveHttp(ResponseHeaders(event.stream_id, response, bool(event.stream_ended)))
return False
elif isinstance(event, h2.events.TrailersReceived):
pseudo_trailers, trailers = split_pseudo_headers(event.headers)
yield ReceiveHttp(ResponseTrailers(event.stream_id, trailers, bool(event.stream_ended)))
trailers = http.Headers(event.headers)
yield ReceiveHttp(ResponseTrailers(event.stream_id, trailers))
return False
elif isinstance(event, h2.events.RequestReceived):
yield from self.protocol_error(f"HTTP/2 protocol error: received request from server")

View File

@ -30,9 +30,14 @@ example_response_headers = (
(b':status', b'200'),
)
example_request_trailers = (
(b'req-trailer-a', b'a'),
(b'req-trailer-b', b'b')
)
example_response_trailers = (
(b'my-trailer-b', b'0'),
(b'my-trailer-b', b'0')
(b'resp-trailer-a', b'a'),
(b'resp-trailer-b', b'b')
)
@ -108,7 +113,7 @@ def test_simple(tctx):
assert flow().response.text == "Hello, World!"
def test_trailers(tctx):
def test_response_trailers(tctx):
playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow)
server = Placeholder(Server)
@ -152,6 +157,37 @@ def test_trailers(tctx):
assert flow().response.text == "Hello, World!"
def test_request_trailers(tctx):
playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow)
server = Placeholder(Server)
initial = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_headers).serialize())
<< http.HttpRequestHeadersHook(flow)
>> reply()
>> DataReceived(tctx.client, cff.build_data_frame(b"Hello, World!").serialize())
>> DataReceived(tctx.client,
cff.build_headers_frame(example_request_trailers, flags=["END_STREAM"]).serialize())
<< http.HttpRequestTrailersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, initial)
)
frames = decode_frames(initial())
assert [type(x) for x in frames] == [
hyperframe.frame.SettingsFrame,
hyperframe.frame.HeadersFrame,
hyperframe.frame.DataFrame,
hyperframe.frame.HeadersFrame,
]
def test_upstream_error(tctx):
playbook, cff = start_h2_client(tctx)
flow = Placeholder(HTTPFlow)