fix streaming

This commit is contained in:
Maximilian Hils 2015-08-29 14:28:11 +02:00
parent 2dfba2105b
commit 63844df343
2 changed files with 124 additions and 72 deletions

View File

@ -25,32 +25,101 @@ from netlib.http.http2 import HTTP2Protocol
# TODO: The HTTP2 layer is missing multiplexing, which requires a major rewrite.
class Http1Layer(Layer):
class _HttpLayer(Layer):
supports_streaming = False
def read_request(self):
raise NotImplementedError()
def send_request(self, request):
raise NotImplementedError()
def read_response(self, request_method):
raise NotImplementedError()
def send_response(self, response):
raise NotImplementedError()
class _StreamingHttpLayer(_HttpLayer):
supports_streaming = True
def read_response_headers(self):
raise NotImplementedError
def read_response_body(self, headers, request_method, response_code, max_chunk_size=None):
raise NotImplementedError()
yield "this is a generator"
def send_response_headers(self, response):
raise NotImplementedError
def send_response_body(self, response, chunks):
raise NotImplementedError()
class Http1Layer(_StreamingHttpLayer):
def __init__(self, ctx, mode):
super(Http1Layer, self).__init__(ctx)
self.mode = mode
self.client_protocol = HTTP1Protocol(self.client_conn)
self.server_protocol = HTTP1Protocol(self.server_conn)
def read_from_client(self):
def read_request(self):
return HTTPRequest.from_protocol(
self.client_protocol,
body_size_limit=self.config.body_size_limit
)
def read_from_server(self, request_method):
def send_request(self, request):
self.server_conn.send(self.server_protocol.assemble(request))
def read_response(self, request_method):
return HTTPResponse.from_protocol(
self.server_protocol,
request_method,
request_method=request_method,
body_size_limit=self.config.body_size_limit,
include_body=False,
include_body=True
)
def send_to_client(self, message):
self.client_conn.send(self.client_protocol.assemble(message))
def send_response(self, response):
self.client_conn.send(self.client_protocol.assemble(response))
def send_to_server(self, message):
self.server_conn.send(self.server_protocol.assemble(message))
def read_response_headers(self):
return HTTPResponse.from_protocol(
self.server_protocol,
request_method=None, # does not matter if we don't read the body.
body_size_limit=self.config.body_size_limit,
include_body=False
)
def read_response_body(self, headers, request_method, response_code, max_chunk_size=None):
return self.server_protocol.read_http_body_chunked(
headers,
self.config.body_size_limit,
request_method,
response_code,
False,
max_chunk_size
)
def send_response_headers(self, response):
h = self.client_protocol._assemble_response_first_line(response)
self.client_conn.wfile.write(h+"\r\n")
h = self.client_protocol._assemble_response_headers(
response,
preserve_transfer_encoding=True
)
self.client_conn.send(h+"\r\n")
def send_response_body(self, response, chunks):
if self.client_protocol.has_chunked_encoding(response.headers):
chunks = (
"%d\r\n%s\r\n" % (len(chunk), chunk)
for chunk in chunks
)
for chunk in chunks:
self.client_conn.send(chunk)
def connect(self):
self.ctx.connect()
@ -69,14 +138,14 @@ class Http1Layer(Layer):
layer()
class Http2Layer(Layer):
class Http2Layer(_HttpLayer):
def __init__(self, ctx, mode):
super(Http2Layer, self).__init__(ctx)
self.mode = mode
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
def read_from_client(self):
def read_request(self):
request = HTTPRequest.from_protocol(
self.client_protocol,
body_size_limit=self.config.body_size_limit
@ -84,23 +153,23 @@ class Http2Layer(Layer):
self._stream_id = request.stream_id
return request
def read_from_server(self, request_method):
def send_request(self, message):
# TODO: implement flow control and WINDOW_UPDATE frames
self.server_conn.send(self.server_protocol.assemble(message))
def read_response(self, request_method):
return HTTPResponse.from_protocol(
self.server_protocol,
request_method,
request_method=request_method,
body_size_limit=self.config.body_size_limit,
include_body=True,
stream_id=self._stream_id
)
def send_to_client(self, message):
def send_response(self, message):
# TODO: implement flow control and WINDOW_UPDATE frames
self.client_conn.send(self.client_protocol.assemble(message))
def send_to_server(self, message):
# TODO: implement flow control and WINDOW_UPDATE frames
self.server_conn.send(self.server_protocol.assemble(message))
def connect(self):
self.ctx.connect()
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
@ -122,7 +191,7 @@ class Http2Layer(Layer):
layer()
def handle_unexpected_frame(self, frm):
print(frm.human_readable())
self.log("Unexpected HTTP2 Frame: %s" % frm.human_readable(), "info")
def make_error_response(status_code, message, headers=None):
@ -204,13 +273,13 @@ class UpstreamConnectLayer(Layer):
def connect(self):
if not self.server_conn:
self.ctx.connect()
self.send_to_server(self.connect_request)
self.send_request(self.connect_request)
else:
pass # swallow the message
def reconnect(self):
self.ctx.reconnect()
self.send_to_server(self.connect_request)
self.send_request(self.connect_request)
def set_server(self, address, server_tls=None, sni=None, depth=1):
if depth == 1:
@ -240,7 +309,7 @@ class HttpLayer(Layer):
flow = HTTPFlow(self.client_conn, self.server_conn, live=self)
try:
request = self.read_from_client()
request = self.read_request()
except tcp.NetLibError:
# don't throw an error for disconnects that happen
# before/between requests.
@ -280,7 +349,7 @@ class HttpLayer(Layer):
except (HttpErrorConnClosed, NetLibError, HttpError, ProtocolException) as e:
try:
self.send_to_client(make_error_response(
self.send_response(make_error_response(
getattr(e, "code", 502),
repr(e)
))
@ -295,7 +364,7 @@ class HttpLayer(Layer):
def handle_regular_mode_connect(self, request):
self.set_server((request.host, request.port))
self.send_to_client(make_connect_response(request.httpversion))
self.send_response(make_connect_response(request.httpversion))
layer = self.ctx.next_layer(self)
layer()
@ -334,44 +403,33 @@ class HttpLayer(Layer):
return close_connection
def send_response_to_client(self, flow):
if not flow.response.stream:
if not (self.supports_streaming and flow.response.stream):
# no streaming:
# we already received the full response from the server and can
# send it to the client straight away.
self.send_to_client(flow.response)
self.send_response(flow.response)
else:
# streaming:
# First send the headers and then transfer the response
# incrementally:
h = self.client_protocol._assemble_response_first_line(flow.response)
self.send_to_client(h + "\r\n")
h = self.client_protocol._assemble_response_headers(flow.response, preserve_transfer_encoding=True)
self.send_to_client(h + "\r\n")
chunks = self.client_protocol.read_http_body_chunked(
flow.response.headers,
self.config.body_size_limit,
flow.request.method,
flow.response.code,
False,
4096
# First send the headers and then transfer the response incrementally
self.send_response_headers(flow.response)
chunks = self.read_response_body(
flow.response.headers,
flow.request.method,
flow.response.code,
max_chunk_size=4096
)
if callable(flow.response.stream):
chunks = flow.response.stream(chunks)
for chunk in chunks:
for part in chunk:
# TODO: That's going to fail.
self.send_to_client(part)
self.client_conn.wfile.flush()
self.send_response_body(flow.response, chunks)
flow.response.timestamp_end = utils.timestamp()
def get_response_from_server(self, flow):
def get_response():
self.send_to_server(flow.request)
flow.response = self.read_from_server(flow.request.method)
self.send_request(flow.request)
if self.supports_streaming:
flow.response = self.read_response_headers()
else:
flow.response = self.read_response()
try:
get_response()
@ -400,18 +458,15 @@ class HttpLayer(Layer):
if flow is None or flow == KILL:
raise Kill()
if isinstance(self.ctx, Http2Layer):
pass # streaming is not implemented for http2 yet.
elif flow.response.stream:
flow.response.content = CONTENT_MISSING
else:
flow.response.content = self.server_protocol.read_http_body(
flow.response.headers,
self.config.body_size_limit,
flow.request.method,
flow.response.code,
False
)
if self.supports_streaming:
if flow.response.stream:
flow.response.content = CONTENT_MISSING
else:
flow.response.content = "".join(self.read_response_body(
flow.response.headers,
flow.request.method,
flow.response.code
))
flow.response.timestamp_end = utils.timestamp()
# no further manipulation of self.server_conn beyond this point
@ -480,14 +535,14 @@ class HttpLayer(Layer):
if self.server_conn.tls_established:
self.reconnect()
self.send_to_server(make_connect_request(address))
self.send_request(make_connect_request(address))
tls_layer = TlsLayer(self, False, True)
tls_layer._establish_tls_with_server()
"""
def validate_request(self, request):
if request.form_in == "absolute" and request.scheme != "http":
self.send_to_client(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
raise HttpException("Invalid request scheme: %s" % request.scheme)
expected_request_forms = {
@ -501,7 +556,7 @@ class HttpLayer(Layer):
err_message = "Invalid HTTP request form (expected: %s, got: %s)" % (
" or ".join(allowed_request_forms), request.form_in
)
self.send_to_client(make_error_response(400, err_message))
self.send_response(make_error_response(400, err_message))
raise HttpException(err_message)
if self.mode == "regular":
@ -512,7 +567,7 @@ class HttpLayer(Layer):
if self.config.authenticator.authenticate(request.headers):
self.config.authenticator.clean(request.headers)
else:
self.send_to_client(make_error_response(
self.send_response(make_error_response(
407,
"Proxy Authentication Required",
odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()])
@ -552,10 +607,7 @@ class RequestReplayThread(threading.Thread):
if not self.flow.response:
# In all modes, we directly connect to the server displayed
if self.config.mode == "upstream":
# FIXME
server_address = self.config.mode.get_upstream_server(
self.flow.client_conn
)[2:]
server_address = self.config.upstream_server.address
server = ServerConnection(server_address)
server.connect()
protocol = HTTP1Protocol(server)

View File

@ -1,6 +1,6 @@
def modify(chunks):
for prefix, content, suffix in chunks:
yield prefix, content.replace("foo", "bar"), suffix
for chunk in chunks:
yield chunk.replace("foo", "bar")
def responseheaders(context, flow):