Localise client connection object manipulation.

This simplifies the call signature for a bunch of functions.
This commit is contained in:
Aldo Cortesi 2012-06-10 16:49:59 +12:00
parent 1f659948cd
commit d60fa9918b
2 changed files with 52 additions and 27 deletions

View File

@ -104,10 +104,16 @@ def read_chunked(fp, limit):
return content return content
def read_http_body(rfile, client_conn, headers, all, limit): def has_chunked_encoding(headers):
if 'transfer-encoding' in headers: for i in headers["transfer-encoding"]:
if not ",".join(headers["transfer-encoding"]).lower() == "chunked": for j in i.split(","):
raise IOError('Invalid transfer-encoding') if j.lower() == "chunked":
return True
return False
def read_http_body(rfile, headers, all, limit):
if has_chunked_encoding(headers):
content = read_chunked(rfile, limit) content = read_chunked(rfile, limit)
elif "content-length" in headers: elif "content-length" in headers:
try: try:
@ -121,7 +127,6 @@ def read_http_body(rfile, client_conn, headers, all, limit):
content = rfile.read(l) content = rfile.read(l)
elif all: elif all:
content = rfile.read(limit if limit else None) content = rfile.read(limit if limit else None)
client_conn.close = True
else: else:
content = "" content = ""
return content return content
@ -185,10 +190,9 @@ def parse_init_http(line):
return method, url, httpversion return method, url, httpversion
def should_connection_close(httpversion, headers): def request_connection_close(httpversion, headers):
""" """
Checks the HTTP version and headers to see if this connection should be Checks the request to see if the client connection should be closed.
closed.
""" """
if "connection" in headers: if "connection" in headers:
for value in ",".join(headers['connection']).split(","): for value in ",".join(headers['connection']).split(","):
@ -203,7 +207,18 @@ def should_connection_close(httpversion, headers):
return True return True
def read_http_body_request(rfile, wfile, client_conn, headers, httpversion, limit): def response_connection_close(httpversion, headers):
"""
Checks the response to see if the client connection should be closed.
"""
if request_connection_close(httpversion, headers):
return True
elif not has_chunked_encoding(headers) and "content-length" in headers:
return True
return False
def read_http_body_request(rfile, wfile, headers, httpversion, limit):
if "expect" in headers: if "expect" in headers:
# FIXME: Should be forwarded upstream # FIXME: Should be forwarded upstream
expect = ",".join(headers['expect']) expect = ",".join(headers['expect'])
@ -212,7 +227,7 @@ def read_http_body_request(rfile, wfile, client_conn, headers, httpversion, limi
wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
wfile.write('\r\n') wfile.write('\r\n')
del headers['expect'] del headers['expect']
return read_http_body(rfile, client_conn, headers, False, limit) return read_http_body(rfile, headers, False, limit)
class FileLike: class FileLike:
@ -335,7 +350,7 @@ class ServerConnection:
if request.method == "HEAD" or code == 204 or code == 304: if request.method == "HEAD" or code == 204 or code == 304:
content = "" content = ""
else: else:
content = read_http_body(self.rfile, self, headers, True, self.config.body_size_limit) content = read_http_body(self.rfile, headers, True, self.config.body_size_limit)
return flow.Response(request, httpversion, code, msg, headers, content, self.cert) return flow.Response(request, httpversion, code, msg, headers, content, self.cert)
def terminate(self): def terminate(self):
@ -413,7 +428,13 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
if response is None: if response is None:
return return
self.send_response(response) self.send_response(response)
if should_connection_close(request.httpversion, request.headers): if request_connection_close(request.httpversion, request.headers):
return
# We could keep the client connection when the server
# connection needs to go away. However, we want to mimic
# behaviour as closely as possible to the client, so we
# disconnect.
if response_connection_close(response.httpversion, response.headers):
return return
except IOError, v: except IOError, v:
cc.connection_error = v cc.connection_error = v
@ -467,7 +488,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
method, path, httpversion = parse_init_http(line) method, path, httpversion = parse_init_http(line)
headers = read_headers(self.rfile) headers = read_headers(self.rfile)
content = read_http_body_request( content = read_http_body_request(
self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
) )
return flow.Request(client_conn, httpversion, host, port, "http", method, path, headers, content) return flow.Request(client_conn, httpversion, host, port, "http", method, path, headers, content)
else: else:
@ -495,14 +516,14 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
method, path, httpversion = parse_init_http(line) method, path, httpversion = parse_init_http(line)
headers = read_headers(self.rfile) headers = read_headers(self.rfile)
content = read_http_body_request( content = read_http_body_request(
self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
) )
return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content) return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content)
else: else:
method, scheme, host, port, path, httpversion = parse_init_proxy(line) method, scheme, host, port, path, httpversion = parse_init_proxy(line)
headers = read_headers(self.rfile) headers = read_headers(self.rfile)
content = read_http_body_request( content = read_http_body_request(
self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
) )
return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content) return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content)

View File

@ -4,7 +4,12 @@ import libpry
from libmproxy import proxy, flow from libmproxy import proxy, flow
import tutils import tutils
class Dummy: pass
def test_has_chunked_encoding():
h = flow.ODictCaseless()
assert not proxy.has_chunked_encoding(h)
h["transfer-encoding"] = ["chunked"]
assert proxy.has_chunked_encoding(h)
def test_read_chunked(): def test_read_chunked():
@ -24,36 +29,35 @@ def test_read_chunked():
tutils.raises(proxy.ProxyError, proxy.read_chunked, s, None) tutils.raises(proxy.ProxyError, proxy.read_chunked, s, None)
def test_should_connection_close(): def test_request_connection_close():
h = flow.ODictCaseless() h = flow.ODictCaseless()
assert proxy.should_connection_close((1, 0), h) assert proxy.request_connection_close((1, 0), h)
assert not proxy.should_connection_close((1, 1), h) assert not proxy.request_connection_close((1, 1), h)
h["connection"] = ["keep-alive"] h["connection"] = ["keep-alive"]
assert not proxy.should_connection_close((1, 1), h) assert not proxy.request_connection_close((1, 1), h)
def test_read_http_body(): def test_read_http_body():
d = Dummy()
h = flow.ODict() h = flow.ODict()
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert proxy.read_http_body(s, d, h, False, None) == "" assert proxy.read_http_body(s, h, False, None) == ""
h["content-length"] = ["foo"] h["content-length"] = ["foo"]
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
tutils.raises(proxy.ProxyError, proxy.read_http_body, s, d, h, False, None) tutils.raises(proxy.ProxyError, proxy.read_http_body, s, h, False, None)
h["content-length"] = [5] h["content-length"] = [5]
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert len(proxy.read_http_body(s, d, h, False, None)) == 5 assert len(proxy.read_http_body(s, h, False, None)) == 5
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
tutils.raises(proxy.ProxyError, proxy.read_http_body, s, d, h, False, 4) tutils.raises(proxy.ProxyError, proxy.read_http_body, s, h, False, 4)
h = flow.ODict() h = flow.ODict()
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert len(proxy.read_http_body(s, d, h, True, 4)) == 4 assert len(proxy.read_http_body(s, h, True, 4)) == 4
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert len(proxy.read_http_body(s, d, h, True, 100)) == 7 assert len(proxy.read_http_body(s, h, True, 100)) == 7
class TestFileLike: class TestFileLike: