fix streaming

This commit is contained in:
Maximilian Hils 2015-08-29 14:28:11 +02:00
parent 2dfba2105b
commit 63844df343
2 changed files with 124 additions and 72 deletions

View File

@ -25,32 +25,101 @@ from netlib.http.http2 import HTTP2Protocol
# TODO: The HTTP2 layer is missing multiplexing, which requires a major rewrite. # TODO: The HTTP2 layer is missing multiplexing, which requires a major rewrite.
class Http1Layer(Layer): class _HttpLayer(Layer):
supports_streaming = False
def read_request(self):
raise NotImplementedError()
def send_request(self, request):
raise NotImplementedError()
def read_response(self, request_method):
raise NotImplementedError()
def send_response(self, response):
raise NotImplementedError()
class _StreamingHttpLayer(_HttpLayer):
supports_streaming = True
def read_response_headers(self):
raise NotImplementedError
def read_response_body(self, headers, request_method, response_code, max_chunk_size=None):
raise NotImplementedError()
yield "this is a generator"
def send_response_headers(self, response):
raise NotImplementedError
def send_response_body(self, response, chunks):
raise NotImplementedError()
class Http1Layer(_StreamingHttpLayer):
def __init__(self, ctx, mode): def __init__(self, ctx, mode):
super(Http1Layer, self).__init__(ctx) super(Http1Layer, self).__init__(ctx)
self.mode = mode self.mode = mode
self.client_protocol = HTTP1Protocol(self.client_conn) self.client_protocol = HTTP1Protocol(self.client_conn)
self.server_protocol = HTTP1Protocol(self.server_conn) self.server_protocol = HTTP1Protocol(self.server_conn)
def read_from_client(self): def read_request(self):
return HTTPRequest.from_protocol( return HTTPRequest.from_protocol(
self.client_protocol, self.client_protocol,
body_size_limit=self.config.body_size_limit body_size_limit=self.config.body_size_limit
) )
def read_from_server(self, request_method): def send_request(self, request):
self.server_conn.send(self.server_protocol.assemble(request))
def read_response(self, request_method):
return HTTPResponse.from_protocol( return HTTPResponse.from_protocol(
self.server_protocol, self.server_protocol,
request_method, request_method=request_method,
body_size_limit=self.config.body_size_limit, body_size_limit=self.config.body_size_limit,
include_body=False, include_body=True
) )
def send_to_client(self, message): def send_response(self, response):
self.client_conn.send(self.client_protocol.assemble(message)) self.client_conn.send(self.client_protocol.assemble(response))
def send_to_server(self, message): def read_response_headers(self):
self.server_conn.send(self.server_protocol.assemble(message)) return HTTPResponse.from_protocol(
self.server_protocol,
request_method=None, # does not matter if we don't read the body.
body_size_limit=self.config.body_size_limit,
include_body=False
)
def read_response_body(self, headers, request_method, response_code, max_chunk_size=None):
return self.server_protocol.read_http_body_chunked(
headers,
self.config.body_size_limit,
request_method,
response_code,
False,
max_chunk_size
)
def send_response_headers(self, response):
h = self.client_protocol._assemble_response_first_line(response)
self.client_conn.wfile.write(h+"\r\n")
h = self.client_protocol._assemble_response_headers(
response,
preserve_transfer_encoding=True
)
self.client_conn.send(h+"\r\n")
def send_response_body(self, response, chunks):
if self.client_protocol.has_chunked_encoding(response.headers):
chunks = (
"%d\r\n%s\r\n" % (len(chunk), chunk)
for chunk in chunks
)
for chunk in chunks:
self.client_conn.send(chunk)
def connect(self): def connect(self):
self.ctx.connect() self.ctx.connect()
@ -69,14 +138,14 @@ class Http1Layer(Layer):
layer() layer()
class Http2Layer(Layer): class Http2Layer(_HttpLayer):
def __init__(self, ctx, mode): def __init__(self, ctx, mode):
super(Http2Layer, self).__init__(ctx) super(Http2Layer, self).__init__(ctx)
self.mode = mode self.mode = mode
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame) self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
def read_from_client(self): def read_request(self):
request = HTTPRequest.from_protocol( request = HTTPRequest.from_protocol(
self.client_protocol, self.client_protocol,
body_size_limit=self.config.body_size_limit body_size_limit=self.config.body_size_limit
@ -84,23 +153,23 @@ class Http2Layer(Layer):
self._stream_id = request.stream_id self._stream_id = request.stream_id
return request return request
def read_from_server(self, request_method): def send_request(self, message):
# TODO: implement flow control and WINDOW_UPDATE frames
self.server_conn.send(self.server_protocol.assemble(message))
def read_response(self, request_method):
return HTTPResponse.from_protocol( return HTTPResponse.from_protocol(
self.server_protocol, self.server_protocol,
request_method, request_method=request_method,
body_size_limit=self.config.body_size_limit, body_size_limit=self.config.body_size_limit,
include_body=True, include_body=True,
stream_id=self._stream_id stream_id=self._stream_id
) )
def send_to_client(self, message): def send_response(self, message):
# TODO: implement flow control and WINDOW_UPDATE frames # TODO: implement flow control and WINDOW_UPDATE frames
self.client_conn.send(self.client_protocol.assemble(message)) self.client_conn.send(self.client_protocol.assemble(message))
def send_to_server(self, message):
# TODO: implement flow control and WINDOW_UPDATE frames
self.server_conn.send(self.server_protocol.assemble(message))
def connect(self): def connect(self):
self.ctx.connect() self.ctx.connect()
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
@ -122,7 +191,7 @@ class Http2Layer(Layer):
layer() layer()
def handle_unexpected_frame(self, frm): def handle_unexpected_frame(self, frm):
print(frm.human_readable()) self.log("Unexpected HTTP2 Frame: %s" % frm.human_readable(), "info")
def make_error_response(status_code, message, headers=None): def make_error_response(status_code, message, headers=None):
@ -204,13 +273,13 @@ class UpstreamConnectLayer(Layer):
def connect(self): def connect(self):
if not self.server_conn: if not self.server_conn:
self.ctx.connect() self.ctx.connect()
self.send_to_server(self.connect_request) self.send_request(self.connect_request)
else: else:
pass # swallow the message pass # swallow the message
def reconnect(self): def reconnect(self):
self.ctx.reconnect() self.ctx.reconnect()
self.send_to_server(self.connect_request) self.send_request(self.connect_request)
def set_server(self, address, server_tls=None, sni=None, depth=1): def set_server(self, address, server_tls=None, sni=None, depth=1):
if depth == 1: if depth == 1:
@ -240,7 +309,7 @@ class HttpLayer(Layer):
flow = HTTPFlow(self.client_conn, self.server_conn, live=self) flow = HTTPFlow(self.client_conn, self.server_conn, live=self)
try: try:
request = self.read_from_client() request = self.read_request()
except tcp.NetLibError: except tcp.NetLibError:
# don't throw an error for disconnects that happen # don't throw an error for disconnects that happen
# before/between requests. # before/between requests.
@ -280,7 +349,7 @@ class HttpLayer(Layer):
except (HttpErrorConnClosed, NetLibError, HttpError, ProtocolException) as e: except (HttpErrorConnClosed, NetLibError, HttpError, ProtocolException) as e:
try: try:
self.send_to_client(make_error_response( self.send_response(make_error_response(
getattr(e, "code", 502), getattr(e, "code", 502),
repr(e) repr(e)
)) ))
@ -295,7 +364,7 @@ class HttpLayer(Layer):
def handle_regular_mode_connect(self, request): def handle_regular_mode_connect(self, request):
self.set_server((request.host, request.port)) self.set_server((request.host, request.port))
self.send_to_client(make_connect_response(request.httpversion)) self.send_response(make_connect_response(request.httpversion))
layer = self.ctx.next_layer(self) layer = self.ctx.next_layer(self)
layer() layer()
@ -334,44 +403,33 @@ class HttpLayer(Layer):
return close_connection return close_connection
def send_response_to_client(self, flow): def send_response_to_client(self, flow):
if not flow.response.stream: if not (self.supports_streaming and flow.response.stream):
# no streaming: # no streaming:
# we already received the full response from the server and can # we already received the full response from the server and can
# send it to the client straight away. # send it to the client straight away.
self.send_to_client(flow.response) self.send_response(flow.response)
else: else:
# streaming: # streaming:
# First send the headers and then transfer the response # First send the headers and then transfer the response incrementally
# incrementally: self.send_response_headers(flow.response)
h = self.client_protocol._assemble_response_first_line(flow.response) chunks = self.read_response_body(
self.send_to_client(h + "\r\n") flow.response.headers,
h = self.client_protocol._assemble_response_headers(flow.response, preserve_transfer_encoding=True) flow.request.method,
self.send_to_client(h + "\r\n") flow.response.code,
max_chunk_size=4096
chunks = self.client_protocol.read_http_body_chunked(
flow.response.headers,
self.config.body_size_limit,
flow.request.method,
flow.response.code,
False,
4096
) )
if callable(flow.response.stream): if callable(flow.response.stream):
chunks = flow.response.stream(chunks) chunks = flow.response.stream(chunks)
self.send_response_body(flow.response, chunks)
for chunk in chunks:
for part in chunk:
# TODO: That's going to fail.
self.send_to_client(part)
self.client_conn.wfile.flush()
flow.response.timestamp_end = utils.timestamp() flow.response.timestamp_end = utils.timestamp()
def get_response_from_server(self, flow): def get_response_from_server(self, flow):
def get_response(): def get_response():
self.send_to_server(flow.request) self.send_request(flow.request)
flow.response = self.read_from_server(flow.request.method) if self.supports_streaming:
flow.response = self.read_response_headers()
else:
flow.response = self.read_response()
try: try:
get_response() get_response()
@ -400,18 +458,15 @@ class HttpLayer(Layer):
if flow is None or flow == KILL: if flow is None or flow == KILL:
raise Kill() raise Kill()
if isinstance(self.ctx, Http2Layer): if self.supports_streaming:
pass # streaming is not implemented for http2 yet. if flow.response.stream:
elif flow.response.stream: flow.response.content = CONTENT_MISSING
flow.response.content = CONTENT_MISSING else:
else: flow.response.content = "".join(self.read_response_body(
flow.response.content = self.server_protocol.read_http_body( flow.response.headers,
flow.response.headers, flow.request.method,
self.config.body_size_limit, flow.response.code
flow.request.method, ))
flow.response.code,
False
)
flow.response.timestamp_end = utils.timestamp() flow.response.timestamp_end = utils.timestamp()
# no further manipulation of self.server_conn beyond this point # no further manipulation of self.server_conn beyond this point
@ -480,14 +535,14 @@ class HttpLayer(Layer):
if self.server_conn.tls_established: if self.server_conn.tls_established:
self.reconnect() self.reconnect()
self.send_to_server(make_connect_request(address)) self.send_request(make_connect_request(address))
tls_layer = TlsLayer(self, False, True) tls_layer = TlsLayer(self, False, True)
tls_layer._establish_tls_with_server() tls_layer._establish_tls_with_server()
""" """
def validate_request(self, request): def validate_request(self, request):
if request.form_in == "absolute" and request.scheme != "http": if request.form_in == "absolute" and request.scheme != "http":
self.send_to_client(make_error_response(400, "Invalid request scheme: %s" % request.scheme)) self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
raise HttpException("Invalid request scheme: %s" % request.scheme) raise HttpException("Invalid request scheme: %s" % request.scheme)
expected_request_forms = { expected_request_forms = {
@ -501,7 +556,7 @@ class HttpLayer(Layer):
err_message = "Invalid HTTP request form (expected: %s, got: %s)" % ( err_message = "Invalid HTTP request form (expected: %s, got: %s)" % (
" or ".join(allowed_request_forms), request.form_in " or ".join(allowed_request_forms), request.form_in
) )
self.send_to_client(make_error_response(400, err_message)) self.send_response(make_error_response(400, err_message))
raise HttpException(err_message) raise HttpException(err_message)
if self.mode == "regular": if self.mode == "regular":
@ -512,7 +567,7 @@ class HttpLayer(Layer):
if self.config.authenticator.authenticate(request.headers): if self.config.authenticator.authenticate(request.headers):
self.config.authenticator.clean(request.headers) self.config.authenticator.clean(request.headers)
else: else:
self.send_to_client(make_error_response( self.send_response(make_error_response(
407, 407,
"Proxy Authentication Required", "Proxy Authentication Required",
odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()]) odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()])
@ -552,10 +607,7 @@ class RequestReplayThread(threading.Thread):
if not self.flow.response: if not self.flow.response:
# In all modes, we directly connect to the server displayed # In all modes, we directly connect to the server displayed
if self.config.mode == "upstream": if self.config.mode == "upstream":
# FIXME server_address = self.config.upstream_server.address
server_address = self.config.mode.get_upstream_server(
self.flow.client_conn
)[2:]
server = ServerConnection(server_address) server = ServerConnection(server_address)
server.connect() server.connect()
protocol = HTTP1Protocol(server) protocol = HTTP1Protocol(server)

View File

@ -1,6 +1,6 @@
def modify(chunks): def modify(chunks):
for prefix, content, suffix in chunks: for chunk in chunks:
yield prefix, content.replace("foo", "bar"), suffix yield chunk.replace("foo", "bar")
def responseheaders(context, flow): def responseheaders(context, flow):