mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
fix streaming
This commit is contained in:
parent
2dfba2105b
commit
63844df343
@ -25,32 +25,101 @@ from netlib.http.http2 import HTTP2Protocol
|
|||||||
# TODO: The HTTP2 layer is missing multiplexing, which requires a major rewrite.
|
# TODO: The HTTP2 layer is missing multiplexing, which requires a major rewrite.
|
||||||
|
|
||||||
|
|
||||||
class Http1Layer(Layer):
|
class _HttpLayer(Layer):
|
||||||
|
supports_streaming = False
|
||||||
|
|
||||||
|
def read_request(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def send_request(self, request):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def read_response(self, request_method):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def send_response(self, response):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
class _StreamingHttpLayer(_HttpLayer):
|
||||||
|
supports_streaming = True
|
||||||
|
|
||||||
|
def read_response_headers(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def read_response_body(self, headers, request_method, response_code, max_chunk_size=None):
|
||||||
|
raise NotImplementedError()
|
||||||
|
yield "this is a generator"
|
||||||
|
|
||||||
|
def send_response_headers(self, response):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def send_response_body(self, response, chunks):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class Http1Layer(_StreamingHttpLayer):
|
||||||
|
|
||||||
def __init__(self, ctx, mode):
|
def __init__(self, ctx, mode):
|
||||||
super(Http1Layer, self).__init__(ctx)
|
super(Http1Layer, self).__init__(ctx)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.client_protocol = HTTP1Protocol(self.client_conn)
|
self.client_protocol = HTTP1Protocol(self.client_conn)
|
||||||
self.server_protocol = HTTP1Protocol(self.server_conn)
|
self.server_protocol = HTTP1Protocol(self.server_conn)
|
||||||
|
|
||||||
def read_from_client(self):
|
def read_request(self):
|
||||||
return HTTPRequest.from_protocol(
|
return HTTPRequest.from_protocol(
|
||||||
self.client_protocol,
|
self.client_protocol,
|
||||||
body_size_limit=self.config.body_size_limit
|
body_size_limit=self.config.body_size_limit
|
||||||
)
|
)
|
||||||
|
|
||||||
def read_from_server(self, request_method):
|
def send_request(self, request):
|
||||||
|
self.server_conn.send(self.server_protocol.assemble(request))
|
||||||
|
|
||||||
|
def read_response(self, request_method):
|
||||||
return HTTPResponse.from_protocol(
|
return HTTPResponse.from_protocol(
|
||||||
self.server_protocol,
|
self.server_protocol,
|
||||||
request_method,
|
request_method=request_method,
|
||||||
body_size_limit=self.config.body_size_limit,
|
body_size_limit=self.config.body_size_limit,
|
||||||
include_body=False,
|
include_body=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_to_client(self, message):
|
def send_response(self, response):
|
||||||
self.client_conn.send(self.client_protocol.assemble(message))
|
self.client_conn.send(self.client_protocol.assemble(response))
|
||||||
|
|
||||||
def send_to_server(self, message):
|
def read_response_headers(self):
|
||||||
self.server_conn.send(self.server_protocol.assemble(message))
|
return HTTPResponse.from_protocol(
|
||||||
|
self.server_protocol,
|
||||||
|
request_method=None, # does not matter if we don't read the body.
|
||||||
|
body_size_limit=self.config.body_size_limit,
|
||||||
|
include_body=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def read_response_body(self, headers, request_method, response_code, max_chunk_size=None):
|
||||||
|
return self.server_protocol.read_http_body_chunked(
|
||||||
|
headers,
|
||||||
|
self.config.body_size_limit,
|
||||||
|
request_method,
|
||||||
|
response_code,
|
||||||
|
False,
|
||||||
|
max_chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def send_response_headers(self, response):
|
||||||
|
h = self.client_protocol._assemble_response_first_line(response)
|
||||||
|
self.client_conn.wfile.write(h+"\r\n")
|
||||||
|
h = self.client_protocol._assemble_response_headers(
|
||||||
|
response,
|
||||||
|
preserve_transfer_encoding=True
|
||||||
|
)
|
||||||
|
self.client_conn.send(h+"\r\n")
|
||||||
|
|
||||||
|
def send_response_body(self, response, chunks):
|
||||||
|
if self.client_protocol.has_chunked_encoding(response.headers):
|
||||||
|
chunks = (
|
||||||
|
"%d\r\n%s\r\n" % (len(chunk), chunk)
|
||||||
|
for chunk in chunks
|
||||||
|
)
|
||||||
|
for chunk in chunks:
|
||||||
|
self.client_conn.send(chunk)
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self.ctx.connect()
|
self.ctx.connect()
|
||||||
@ -69,14 +138,14 @@ class Http1Layer(Layer):
|
|||||||
layer()
|
layer()
|
||||||
|
|
||||||
|
|
||||||
class Http2Layer(Layer):
|
class Http2Layer(_HttpLayer):
|
||||||
def __init__(self, ctx, mode):
|
def __init__(self, ctx, mode):
|
||||||
super(Http2Layer, self).__init__(ctx)
|
super(Http2Layer, self).__init__(ctx)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame)
|
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||||
|
|
||||||
def read_from_client(self):
|
def read_request(self):
|
||||||
request = HTTPRequest.from_protocol(
|
request = HTTPRequest.from_protocol(
|
||||||
self.client_protocol,
|
self.client_protocol,
|
||||||
body_size_limit=self.config.body_size_limit
|
body_size_limit=self.config.body_size_limit
|
||||||
@ -84,23 +153,23 @@ class Http2Layer(Layer):
|
|||||||
self._stream_id = request.stream_id
|
self._stream_id = request.stream_id
|
||||||
return request
|
return request
|
||||||
|
|
||||||
def read_from_server(self, request_method):
|
def send_request(self, message):
|
||||||
|
# TODO: implement flow control and WINDOW_UPDATE frames
|
||||||
|
self.server_conn.send(self.server_protocol.assemble(message))
|
||||||
|
|
||||||
|
def read_response(self, request_method):
|
||||||
return HTTPResponse.from_protocol(
|
return HTTPResponse.from_protocol(
|
||||||
self.server_protocol,
|
self.server_protocol,
|
||||||
request_method,
|
request_method=request_method,
|
||||||
body_size_limit=self.config.body_size_limit,
|
body_size_limit=self.config.body_size_limit,
|
||||||
include_body=True,
|
include_body=True,
|
||||||
stream_id=self._stream_id
|
stream_id=self._stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_to_client(self, message):
|
def send_response(self, message):
|
||||||
# TODO: implement flow control and WINDOW_UPDATE frames
|
# TODO: implement flow control and WINDOW_UPDATE frames
|
||||||
self.client_conn.send(self.client_protocol.assemble(message))
|
self.client_conn.send(self.client_protocol.assemble(message))
|
||||||
|
|
||||||
def send_to_server(self, message):
|
|
||||||
# TODO: implement flow control and WINDOW_UPDATE frames
|
|
||||||
self.server_conn.send(self.server_protocol.assemble(message))
|
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self.ctx.connect()
|
self.ctx.connect()
|
||||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||||
@ -122,7 +191,7 @@ class Http2Layer(Layer):
|
|||||||
layer()
|
layer()
|
||||||
|
|
||||||
def handle_unexpected_frame(self, frm):
|
def handle_unexpected_frame(self, frm):
|
||||||
print(frm.human_readable())
|
self.log("Unexpected HTTP2 Frame: %s" % frm.human_readable(), "info")
|
||||||
|
|
||||||
|
|
||||||
def make_error_response(status_code, message, headers=None):
|
def make_error_response(status_code, message, headers=None):
|
||||||
@ -204,13 +273,13 @@ class UpstreamConnectLayer(Layer):
|
|||||||
def connect(self):
|
def connect(self):
|
||||||
if not self.server_conn:
|
if not self.server_conn:
|
||||||
self.ctx.connect()
|
self.ctx.connect()
|
||||||
self.send_to_server(self.connect_request)
|
self.send_request(self.connect_request)
|
||||||
else:
|
else:
|
||||||
pass # swallow the message
|
pass # swallow the message
|
||||||
|
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
self.ctx.reconnect()
|
self.ctx.reconnect()
|
||||||
self.send_to_server(self.connect_request)
|
self.send_request(self.connect_request)
|
||||||
|
|
||||||
def set_server(self, address, server_tls=None, sni=None, depth=1):
|
def set_server(self, address, server_tls=None, sni=None, depth=1):
|
||||||
if depth == 1:
|
if depth == 1:
|
||||||
@ -240,7 +309,7 @@ class HttpLayer(Layer):
|
|||||||
flow = HTTPFlow(self.client_conn, self.server_conn, live=self)
|
flow = HTTPFlow(self.client_conn, self.server_conn, live=self)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request = self.read_from_client()
|
request = self.read_request()
|
||||||
except tcp.NetLibError:
|
except tcp.NetLibError:
|
||||||
# don't throw an error for disconnects that happen
|
# don't throw an error for disconnects that happen
|
||||||
# before/between requests.
|
# before/between requests.
|
||||||
@ -280,7 +349,7 @@ class HttpLayer(Layer):
|
|||||||
|
|
||||||
except (HttpErrorConnClosed, NetLibError, HttpError, ProtocolException) as e:
|
except (HttpErrorConnClosed, NetLibError, HttpError, ProtocolException) as e:
|
||||||
try:
|
try:
|
||||||
self.send_to_client(make_error_response(
|
self.send_response(make_error_response(
|
||||||
getattr(e, "code", 502),
|
getattr(e, "code", 502),
|
||||||
repr(e)
|
repr(e)
|
||||||
))
|
))
|
||||||
@ -295,7 +364,7 @@ class HttpLayer(Layer):
|
|||||||
|
|
||||||
def handle_regular_mode_connect(self, request):
|
def handle_regular_mode_connect(self, request):
|
||||||
self.set_server((request.host, request.port))
|
self.set_server((request.host, request.port))
|
||||||
self.send_to_client(make_connect_response(request.httpversion))
|
self.send_response(make_connect_response(request.httpversion))
|
||||||
layer = self.ctx.next_layer(self)
|
layer = self.ctx.next_layer(self)
|
||||||
layer()
|
layer()
|
||||||
|
|
||||||
@ -334,44 +403,33 @@ class HttpLayer(Layer):
|
|||||||
return close_connection
|
return close_connection
|
||||||
|
|
||||||
def send_response_to_client(self, flow):
|
def send_response_to_client(self, flow):
|
||||||
if not flow.response.stream:
|
if not (self.supports_streaming and flow.response.stream):
|
||||||
# no streaming:
|
# no streaming:
|
||||||
# we already received the full response from the server and can
|
# we already received the full response from the server and can
|
||||||
# send it to the client straight away.
|
# send it to the client straight away.
|
||||||
self.send_to_client(flow.response)
|
self.send_response(flow.response)
|
||||||
else:
|
else:
|
||||||
# streaming:
|
# streaming:
|
||||||
# First send the headers and then transfer the response
|
# First send the headers and then transfer the response incrementally
|
||||||
# incrementally:
|
self.send_response_headers(flow.response)
|
||||||
h = self.client_protocol._assemble_response_first_line(flow.response)
|
chunks = self.read_response_body(
|
||||||
self.send_to_client(h + "\r\n")
|
flow.response.headers,
|
||||||
h = self.client_protocol._assemble_response_headers(flow.response, preserve_transfer_encoding=True)
|
flow.request.method,
|
||||||
self.send_to_client(h + "\r\n")
|
flow.response.code,
|
||||||
|
max_chunk_size=4096
|
||||||
chunks = self.client_protocol.read_http_body_chunked(
|
|
||||||
flow.response.headers,
|
|
||||||
self.config.body_size_limit,
|
|
||||||
flow.request.method,
|
|
||||||
flow.response.code,
|
|
||||||
False,
|
|
||||||
4096
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if callable(flow.response.stream):
|
if callable(flow.response.stream):
|
||||||
chunks = flow.response.stream(chunks)
|
chunks = flow.response.stream(chunks)
|
||||||
|
self.send_response_body(flow.response, chunks)
|
||||||
for chunk in chunks:
|
|
||||||
for part in chunk:
|
|
||||||
# TODO: That's going to fail.
|
|
||||||
self.send_to_client(part)
|
|
||||||
self.client_conn.wfile.flush()
|
|
||||||
|
|
||||||
flow.response.timestamp_end = utils.timestamp()
|
flow.response.timestamp_end = utils.timestamp()
|
||||||
|
|
||||||
def get_response_from_server(self, flow):
|
def get_response_from_server(self, flow):
|
||||||
def get_response():
|
def get_response():
|
||||||
self.send_to_server(flow.request)
|
self.send_request(flow.request)
|
||||||
flow.response = self.read_from_server(flow.request.method)
|
if self.supports_streaming:
|
||||||
|
flow.response = self.read_response_headers()
|
||||||
|
else:
|
||||||
|
flow.response = self.read_response()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
get_response()
|
get_response()
|
||||||
@ -400,18 +458,15 @@ class HttpLayer(Layer):
|
|||||||
if flow is None or flow == KILL:
|
if flow is None or flow == KILL:
|
||||||
raise Kill()
|
raise Kill()
|
||||||
|
|
||||||
if isinstance(self.ctx, Http2Layer):
|
if self.supports_streaming:
|
||||||
pass # streaming is not implemented for http2 yet.
|
if flow.response.stream:
|
||||||
elif flow.response.stream:
|
flow.response.content = CONTENT_MISSING
|
||||||
flow.response.content = CONTENT_MISSING
|
else:
|
||||||
else:
|
flow.response.content = "".join(self.read_response_body(
|
||||||
flow.response.content = self.server_protocol.read_http_body(
|
flow.response.headers,
|
||||||
flow.response.headers,
|
flow.request.method,
|
||||||
self.config.body_size_limit,
|
flow.response.code
|
||||||
flow.request.method,
|
))
|
||||||
flow.response.code,
|
|
||||||
False
|
|
||||||
)
|
|
||||||
flow.response.timestamp_end = utils.timestamp()
|
flow.response.timestamp_end = utils.timestamp()
|
||||||
|
|
||||||
# no further manipulation of self.server_conn beyond this point
|
# no further manipulation of self.server_conn beyond this point
|
||||||
@ -480,14 +535,14 @@ class HttpLayer(Layer):
|
|||||||
if self.server_conn.tls_established:
|
if self.server_conn.tls_established:
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
|
|
||||||
self.send_to_server(make_connect_request(address))
|
self.send_request(make_connect_request(address))
|
||||||
tls_layer = TlsLayer(self, False, True)
|
tls_layer = TlsLayer(self, False, True)
|
||||||
tls_layer._establish_tls_with_server()
|
tls_layer._establish_tls_with_server()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validate_request(self, request):
|
def validate_request(self, request):
|
||||||
if request.form_in == "absolute" and request.scheme != "http":
|
if request.form_in == "absolute" and request.scheme != "http":
|
||||||
self.send_to_client(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
|
self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
|
||||||
raise HttpException("Invalid request scheme: %s" % request.scheme)
|
raise HttpException("Invalid request scheme: %s" % request.scheme)
|
||||||
|
|
||||||
expected_request_forms = {
|
expected_request_forms = {
|
||||||
@ -501,7 +556,7 @@ class HttpLayer(Layer):
|
|||||||
err_message = "Invalid HTTP request form (expected: %s, got: %s)" % (
|
err_message = "Invalid HTTP request form (expected: %s, got: %s)" % (
|
||||||
" or ".join(allowed_request_forms), request.form_in
|
" or ".join(allowed_request_forms), request.form_in
|
||||||
)
|
)
|
||||||
self.send_to_client(make_error_response(400, err_message))
|
self.send_response(make_error_response(400, err_message))
|
||||||
raise HttpException(err_message)
|
raise HttpException(err_message)
|
||||||
|
|
||||||
if self.mode == "regular":
|
if self.mode == "regular":
|
||||||
@ -512,7 +567,7 @@ class HttpLayer(Layer):
|
|||||||
if self.config.authenticator.authenticate(request.headers):
|
if self.config.authenticator.authenticate(request.headers):
|
||||||
self.config.authenticator.clean(request.headers)
|
self.config.authenticator.clean(request.headers)
|
||||||
else:
|
else:
|
||||||
self.send_to_client(make_error_response(
|
self.send_response(make_error_response(
|
||||||
407,
|
407,
|
||||||
"Proxy Authentication Required",
|
"Proxy Authentication Required",
|
||||||
odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()])
|
odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()])
|
||||||
@ -552,10 +607,7 @@ class RequestReplayThread(threading.Thread):
|
|||||||
if not self.flow.response:
|
if not self.flow.response:
|
||||||
# In all modes, we directly connect to the server displayed
|
# In all modes, we directly connect to the server displayed
|
||||||
if self.config.mode == "upstream":
|
if self.config.mode == "upstream":
|
||||||
# FIXME
|
server_address = self.config.upstream_server.address
|
||||||
server_address = self.config.mode.get_upstream_server(
|
|
||||||
self.flow.client_conn
|
|
||||||
)[2:]
|
|
||||||
server = ServerConnection(server_address)
|
server = ServerConnection(server_address)
|
||||||
server.connect()
|
server.connect()
|
||||||
protocol = HTTP1Protocol(server)
|
protocol = HTTP1Protocol(server)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
def modify(chunks):
|
def modify(chunks):
|
||||||
for prefix, content, suffix in chunks:
|
for chunk in chunks:
|
||||||
yield prefix, content.replace("foo", "bar"), suffix
|
yield chunk.replace("foo", "bar")
|
||||||
|
|
||||||
|
|
||||||
def responseheaders(context, flow):
|
def responseheaders(context, flow):
|
||||||
|
Loading…
Reference in New Issue
Block a user