mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
fix streaming
This commit is contained in:
parent
2dfba2105b
commit
63844df343
@ -25,32 +25,101 @@ from netlib.http.http2 import HTTP2Protocol
|
||||
# TODO: The HTTP2 layer is missing multiplexing, which requires a major rewrite.
|
||||
|
||||
|
||||
class Http1Layer(Layer):
|
||||
class _HttpLayer(Layer):
|
||||
supports_streaming = False
|
||||
|
||||
def read_request(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def send_request(self, request):
|
||||
raise NotImplementedError()
|
||||
|
||||
def read_response(self, request_method):
|
||||
raise NotImplementedError()
|
||||
|
||||
def send_response(self, response):
|
||||
raise NotImplementedError()
|
||||
|
||||
class _StreamingHttpLayer(_HttpLayer):
|
||||
supports_streaming = True
|
||||
|
||||
def read_response_headers(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def read_response_body(self, headers, request_method, response_code, max_chunk_size=None):
|
||||
raise NotImplementedError()
|
||||
yield "this is a generator"
|
||||
|
||||
def send_response_headers(self, response):
|
||||
raise NotImplementedError
|
||||
|
||||
def send_response_body(self, response, chunks):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Http1Layer(_StreamingHttpLayer):
|
||||
|
||||
def __init__(self, ctx, mode):
|
||||
super(Http1Layer, self).__init__(ctx)
|
||||
self.mode = mode
|
||||
self.client_protocol = HTTP1Protocol(self.client_conn)
|
||||
self.server_protocol = HTTP1Protocol(self.server_conn)
|
||||
|
||||
def read_from_client(self):
|
||||
def read_request(self):
|
||||
return HTTPRequest.from_protocol(
|
||||
self.client_protocol,
|
||||
body_size_limit=self.config.body_size_limit
|
||||
)
|
||||
|
||||
def read_from_server(self, request_method):
|
||||
def send_request(self, request):
|
||||
self.server_conn.send(self.server_protocol.assemble(request))
|
||||
|
||||
def read_response(self, request_method):
|
||||
return HTTPResponse.from_protocol(
|
||||
self.server_protocol,
|
||||
request_method,
|
||||
request_method=request_method,
|
||||
body_size_limit=self.config.body_size_limit,
|
||||
include_body=False,
|
||||
include_body=True
|
||||
)
|
||||
|
||||
def send_to_client(self, message):
|
||||
self.client_conn.send(self.client_protocol.assemble(message))
|
||||
def send_response(self, response):
|
||||
self.client_conn.send(self.client_protocol.assemble(response))
|
||||
|
||||
def send_to_server(self, message):
|
||||
self.server_conn.send(self.server_protocol.assemble(message))
|
||||
def read_response_headers(self):
|
||||
return HTTPResponse.from_protocol(
|
||||
self.server_protocol,
|
||||
request_method=None, # does not matter if we don't read the body.
|
||||
body_size_limit=self.config.body_size_limit,
|
||||
include_body=False
|
||||
)
|
||||
|
||||
def read_response_body(self, headers, request_method, response_code, max_chunk_size=None):
|
||||
return self.server_protocol.read_http_body_chunked(
|
||||
headers,
|
||||
self.config.body_size_limit,
|
||||
request_method,
|
||||
response_code,
|
||||
False,
|
||||
max_chunk_size
|
||||
)
|
||||
|
||||
def send_response_headers(self, response):
|
||||
h = self.client_protocol._assemble_response_first_line(response)
|
||||
self.client_conn.wfile.write(h+"\r\n")
|
||||
h = self.client_protocol._assemble_response_headers(
|
||||
response,
|
||||
preserve_transfer_encoding=True
|
||||
)
|
||||
self.client_conn.send(h+"\r\n")
|
||||
|
||||
def send_response_body(self, response, chunks):
|
||||
if self.client_protocol.has_chunked_encoding(response.headers):
|
||||
chunks = (
|
||||
"%d\r\n%s\r\n" % (len(chunk), chunk)
|
||||
for chunk in chunks
|
||||
)
|
||||
for chunk in chunks:
|
||||
self.client_conn.send(chunk)
|
||||
|
||||
def connect(self):
|
||||
self.ctx.connect()
|
||||
@ -69,14 +138,14 @@ class Http1Layer(Layer):
|
||||
layer()
|
||||
|
||||
|
||||
class Http2Layer(Layer):
|
||||
class Http2Layer(_HttpLayer):
|
||||
def __init__(self, ctx, mode):
|
||||
super(Http2Layer, self).__init__(ctx)
|
||||
self.mode = mode
|
||||
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
|
||||
def read_from_client(self):
|
||||
def read_request(self):
|
||||
request = HTTPRequest.from_protocol(
|
||||
self.client_protocol,
|
||||
body_size_limit=self.config.body_size_limit
|
||||
@ -84,23 +153,23 @@ class Http2Layer(Layer):
|
||||
self._stream_id = request.stream_id
|
||||
return request
|
||||
|
||||
def read_from_server(self, request_method):
|
||||
def send_request(self, message):
|
||||
# TODO: implement flow control and WINDOW_UPDATE frames
|
||||
self.server_conn.send(self.server_protocol.assemble(message))
|
||||
|
||||
def read_response(self, request_method):
|
||||
return HTTPResponse.from_protocol(
|
||||
self.server_protocol,
|
||||
request_method,
|
||||
request_method=request_method,
|
||||
body_size_limit=self.config.body_size_limit,
|
||||
include_body=True,
|
||||
stream_id=self._stream_id
|
||||
)
|
||||
|
||||
def send_to_client(self, message):
|
||||
def send_response(self, message):
|
||||
# TODO: implement flow control and WINDOW_UPDATE frames
|
||||
self.client_conn.send(self.client_protocol.assemble(message))
|
||||
|
||||
def send_to_server(self, message):
|
||||
# TODO: implement flow control and WINDOW_UPDATE frames
|
||||
self.server_conn.send(self.server_protocol.assemble(message))
|
||||
|
||||
def connect(self):
|
||||
self.ctx.connect()
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
@ -122,7 +191,7 @@ class Http2Layer(Layer):
|
||||
layer()
|
||||
|
||||
def handle_unexpected_frame(self, frm):
|
||||
print(frm.human_readable())
|
||||
self.log("Unexpected HTTP2 Frame: %s" % frm.human_readable(), "info")
|
||||
|
||||
|
||||
def make_error_response(status_code, message, headers=None):
|
||||
@ -204,13 +273,13 @@ class UpstreamConnectLayer(Layer):
|
||||
def connect(self):
|
||||
if not self.server_conn:
|
||||
self.ctx.connect()
|
||||
self.send_to_server(self.connect_request)
|
||||
self.send_request(self.connect_request)
|
||||
else:
|
||||
pass # swallow the message
|
||||
|
||||
def reconnect(self):
|
||||
self.ctx.reconnect()
|
||||
self.send_to_server(self.connect_request)
|
||||
self.send_request(self.connect_request)
|
||||
|
||||
def set_server(self, address, server_tls=None, sni=None, depth=1):
|
||||
if depth == 1:
|
||||
@ -240,7 +309,7 @@ class HttpLayer(Layer):
|
||||
flow = HTTPFlow(self.client_conn, self.server_conn, live=self)
|
||||
|
||||
try:
|
||||
request = self.read_from_client()
|
||||
request = self.read_request()
|
||||
except tcp.NetLibError:
|
||||
# don't throw an error for disconnects that happen
|
||||
# before/between requests.
|
||||
@ -280,7 +349,7 @@ class HttpLayer(Layer):
|
||||
|
||||
except (HttpErrorConnClosed, NetLibError, HttpError, ProtocolException) as e:
|
||||
try:
|
||||
self.send_to_client(make_error_response(
|
||||
self.send_response(make_error_response(
|
||||
getattr(e, "code", 502),
|
||||
repr(e)
|
||||
))
|
||||
@ -295,7 +364,7 @@ class HttpLayer(Layer):
|
||||
|
||||
def handle_regular_mode_connect(self, request):
|
||||
self.set_server((request.host, request.port))
|
||||
self.send_to_client(make_connect_response(request.httpversion))
|
||||
self.send_response(make_connect_response(request.httpversion))
|
||||
layer = self.ctx.next_layer(self)
|
||||
layer()
|
||||
|
||||
@ -334,44 +403,33 @@ class HttpLayer(Layer):
|
||||
return close_connection
|
||||
|
||||
def send_response_to_client(self, flow):
|
||||
if not flow.response.stream:
|
||||
if not (self.supports_streaming and flow.response.stream):
|
||||
# no streaming:
|
||||
# we already received the full response from the server and can
|
||||
# send it to the client straight away.
|
||||
self.send_to_client(flow.response)
|
||||
self.send_response(flow.response)
|
||||
else:
|
||||
# streaming:
|
||||
# First send the headers and then transfer the response
|
||||
# incrementally:
|
||||
h = self.client_protocol._assemble_response_first_line(flow.response)
|
||||
self.send_to_client(h + "\r\n")
|
||||
h = self.client_protocol._assemble_response_headers(flow.response, preserve_transfer_encoding=True)
|
||||
self.send_to_client(h + "\r\n")
|
||||
|
||||
chunks = self.client_protocol.read_http_body_chunked(
|
||||
flow.response.headers,
|
||||
self.config.body_size_limit,
|
||||
flow.request.method,
|
||||
flow.response.code,
|
||||
False,
|
||||
4096
|
||||
# First send the headers and then transfer the response incrementally
|
||||
self.send_response_headers(flow.response)
|
||||
chunks = self.read_response_body(
|
||||
flow.response.headers,
|
||||
flow.request.method,
|
||||
flow.response.code,
|
||||
max_chunk_size=4096
|
||||
)
|
||||
|
||||
if callable(flow.response.stream):
|
||||
chunks = flow.response.stream(chunks)
|
||||
|
||||
for chunk in chunks:
|
||||
for part in chunk:
|
||||
# TODO: That's going to fail.
|
||||
self.send_to_client(part)
|
||||
self.client_conn.wfile.flush()
|
||||
|
||||
self.send_response_body(flow.response, chunks)
|
||||
flow.response.timestamp_end = utils.timestamp()
|
||||
|
||||
def get_response_from_server(self, flow):
|
||||
def get_response():
|
||||
self.send_to_server(flow.request)
|
||||
flow.response = self.read_from_server(flow.request.method)
|
||||
self.send_request(flow.request)
|
||||
if self.supports_streaming:
|
||||
flow.response = self.read_response_headers()
|
||||
else:
|
||||
flow.response = self.read_response()
|
||||
|
||||
try:
|
||||
get_response()
|
||||
@ -400,18 +458,15 @@ class HttpLayer(Layer):
|
||||
if flow is None or flow == KILL:
|
||||
raise Kill()
|
||||
|
||||
if isinstance(self.ctx, Http2Layer):
|
||||
pass # streaming is not implemented for http2 yet.
|
||||
elif flow.response.stream:
|
||||
flow.response.content = CONTENT_MISSING
|
||||
else:
|
||||
flow.response.content = self.server_protocol.read_http_body(
|
||||
flow.response.headers,
|
||||
self.config.body_size_limit,
|
||||
flow.request.method,
|
||||
flow.response.code,
|
||||
False
|
||||
)
|
||||
if self.supports_streaming:
|
||||
if flow.response.stream:
|
||||
flow.response.content = CONTENT_MISSING
|
||||
else:
|
||||
flow.response.content = "".join(self.read_response_body(
|
||||
flow.response.headers,
|
||||
flow.request.method,
|
||||
flow.response.code
|
||||
))
|
||||
flow.response.timestamp_end = utils.timestamp()
|
||||
|
||||
# no further manipulation of self.server_conn beyond this point
|
||||
@ -480,14 +535,14 @@ class HttpLayer(Layer):
|
||||
if self.server_conn.tls_established:
|
||||
self.reconnect()
|
||||
|
||||
self.send_to_server(make_connect_request(address))
|
||||
self.send_request(make_connect_request(address))
|
||||
tls_layer = TlsLayer(self, False, True)
|
||||
tls_layer._establish_tls_with_server()
|
||||
"""
|
||||
|
||||
def validate_request(self, request):
|
||||
if request.form_in == "absolute" and request.scheme != "http":
|
||||
self.send_to_client(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
|
||||
self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
|
||||
raise HttpException("Invalid request scheme: %s" % request.scheme)
|
||||
|
||||
expected_request_forms = {
|
||||
@ -501,7 +556,7 @@ class HttpLayer(Layer):
|
||||
err_message = "Invalid HTTP request form (expected: %s, got: %s)" % (
|
||||
" or ".join(allowed_request_forms), request.form_in
|
||||
)
|
||||
self.send_to_client(make_error_response(400, err_message))
|
||||
self.send_response(make_error_response(400, err_message))
|
||||
raise HttpException(err_message)
|
||||
|
||||
if self.mode == "regular":
|
||||
@ -512,7 +567,7 @@ class HttpLayer(Layer):
|
||||
if self.config.authenticator.authenticate(request.headers):
|
||||
self.config.authenticator.clean(request.headers)
|
||||
else:
|
||||
self.send_to_client(make_error_response(
|
||||
self.send_response(make_error_response(
|
||||
407,
|
||||
"Proxy Authentication Required",
|
||||
odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()])
|
||||
@ -552,10 +607,7 @@ class RequestReplayThread(threading.Thread):
|
||||
if not self.flow.response:
|
||||
# In all modes, we directly connect to the server displayed
|
||||
if self.config.mode == "upstream":
|
||||
# FIXME
|
||||
server_address = self.config.mode.get_upstream_server(
|
||||
self.flow.client_conn
|
||||
)[2:]
|
||||
server_address = self.config.upstream_server.address
|
||||
server = ServerConnection(server_address)
|
||||
server.connect()
|
||||
protocol = HTTP1Protocol(server)
|
||||
|
@ -1,6 +1,6 @@
|
||||
def modify(chunks):
|
||||
for prefix, content, suffix in chunks:
|
||||
yield prefix, content.replace("foo", "bar"), suffix
|
||||
for chunk in chunks:
|
||||
yield chunk.replace("foo", "bar")
|
||||
|
||||
|
||||
def responseheaders(context, flow):
|
||||
|
Loading…
Reference in New Issue
Block a user