refactor response length handling

This commit is contained in:
Maximilian Hils 2014-07-21 14:01:24 +02:00
parent 280d9b8625
commit 6bd5df79f8
2 changed files with 107 additions and 175 deletions

View File

@ -1,4 +1,5 @@
import string, urlparse, binascii import string, urlparse, binascii
import sys
import odict, utils import odict, utils
@ -88,14 +89,14 @@ def read_headers(fp):
# We're being liberal in what we accept, here. # We're being liberal in what we accept, here.
if i > 0: if i > 0:
name = line[:i] name = line[:i]
value = line[i+1:].strip() value = line[i + 1:].strip()
ret.append([name, value]) ret.append([name, value])
else: else:
return None return None
return odict.ODictCaseless(ret) return odict.ODictCaseless(ret)
def read_chunked(fp, headers, limit, is_request): def read_chunked(fp, limit, is_request):
""" """
Read a chunked HTTP body. Read a chunked HTTP body.
@ -103,10 +104,9 @@ def read_chunked(fp, headers, limit, is_request):
""" """
# FIXME: Should check if chunked is the final encoding in the headers # FIXME: Should check if chunked is the final encoding in the headers
# http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2.
content = ""
total = 0 total = 0
code = 400 if is_request else 502 code = 400 if is_request else 502
while 1: while True:
line = fp.readline(128) line = fp.readline(128)
if line == "": if line == "":
raise HttpErrorConnClosed(code, "Connection closed prematurely") raise HttpErrorConnClosed(code, "Connection closed prematurely")
@ -114,70 +114,19 @@ def read_chunked(fp, headers, limit, is_request):
try: try:
length = int(line, 16) length = int(line, 16)
except ValueError: except ValueError:
# FIXME: Not strictly correct - this could be from the server, in which raise HttpError(code, "Invalid chunked encoding length: %s" % line)
# case we should send a 502.
raise HttpError(code, "Invalid chunked encoding length: %s"%line)
if not length:
break
total += length total += length
if limit is not None and total > limit: if limit is not None and total > limit:
msg = "HTTP Body too large."\ msg = "HTTP Body too large." \
" Limit is %s, chunked content length was at least %s"%(limit, total) " Limit is %s, chunked content length was at least %s" % (limit, total)
raise HttpError(code, msg) raise HttpError(code, msg)
content += fp.read(length) chunk = fp.read(length)
line = fp.readline(5) suffix = fp.readline(5)
if line != '\r\n': if suffix != '\r\n':
raise HttpError(code, "Malformed chunked body") raise HttpError(code, "Malformed chunked body")
while 1: yield line, chunk, '\r\n'
line = fp.readline() if length == 0:
if line == "": return
raise HttpErrorConnClosed(code, "Connection closed prematurely")
if line == '\r\n' or line == '\n':
break
return content
def read_next_chunk(fp, headers, is_request):
"""
Read next piece of a chunked HTTP body. Returns next piece of
content as a string or None if we hit the end.
"""
# TODO: see and understand the FIXME in read_chunked and
# see if we need to apply here?
content = ""
code = 400 if is_request else 502
line = fp.readline(128)
if line == "":
raise HttpErrorConnClosed(code, "Connection closed prematurely")
try:
length = int(line, 16)
except ValueError:
# TODO: see note in this part of read_chunked()
raise HttpError(code, "Invalid chunked encoding length: %s"%line)
if length > 0:
content += fp.read(length)
print "read content: '%s'" % content
line = fp.readline(5)
if line == '':
raise HttpErrorConnClosed(code, "Connection closed prematurely")
if line != '\r\n':
raise HttpError(code, "Malformed chunked body: '%s' (len=%d)" % (line, length))
if content == "":
content = None # normalize zero length to None, meaning end of chunked stream
return content # return this chunk
def write_chunk(fp, content):
"""
Write a chunk with chunked encoding format, returns True
if there should be more chunks or False if you passed
None, meaning this was the last chunk.
"""
if content == None or content == "":
fp.write("0\r\n\r\n")
return False
fp.write("%x\r\n" % len(content))
fp.write(content)
fp.write("\r\n")
return True
def get_header_tokens(headers, key): def get_header_tokens(headers, key):
@ -307,6 +256,7 @@ def parse_init_http(line):
def connection_close(httpversion, headers): def connection_close(httpversion, headers):
""" """
Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1
Note that a connection should be closed as well if the response has been read until end of the stream.
""" """
# At first, check if we have an explicit Connection header. # At first, check if we have an explicit Connection header.
if "connection" in headers: if "connection" in headers:
@ -335,7 +285,7 @@ def parse_response_line(line):
return (proto, code, msg) return (proto, code, msg)
def read_response(rfile, method, body_size_limit, include_body=True): def read_response(rfile, request_method, body_size_limit, include_body=True):
""" """
Return an (httpversion, code, msg, headers, content) tuple. Return an (httpversion, code, msg, headers, content) tuple.
""" """
@ -346,26 +296,27 @@ def read_response(rfile, method, body_size_limit, include_body=True):
raise HttpErrorConnClosed(502, "Server disconnect.") raise HttpErrorConnClosed(502, "Server disconnect.")
parts = parse_response_line(line) parts = parse_response_line(line)
if not parts: if not parts:
raise HttpError(502, "Invalid server response: %s"%repr(line)) raise HttpError(502, "Invalid server response: %s" % repr(line))
proto, code, msg = parts proto, code, msg = parts
httpversion = parse_http_protocol(proto) httpversion = parse_http_protocol(proto)
if httpversion is None: if httpversion is None:
raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto))
headers = read_headers(rfile) headers = read_headers(rfile)
if headers is None: if headers is None:
raise HttpError(502, "Invalid headers.") raise HttpError(502, "Invalid headers.")
# Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 if include_body:
if method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: content = read_http_body(rfile, headers, body_size_limit, request_method, code, False)
content = ""
elif include_body:
content = read_http_body(rfile, headers, body_size_limit, False)
else: else:
content = None # if include_body==False then a None content means the body should be read separately content = None # if include_body==False then a None content means the body should be read separately
return httpversion, code, msg, headers, content return httpversion, code, msg, headers, content
def read_http_body(rfile, headers, limit, is_request): def read_http_body(*args, **kwargs):
return "".join(content for _, content, _ in read_http_body_chunked(*args, **kwargs))
def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None):
""" """
Read an HTTP message body: Read an HTTP message body:
@ -374,41 +325,69 @@ def read_http_body(rfile, headers, limit, is_request):
limit: Size limit. limit: Size limit.
is_request: True if the body to read belongs to a request, False otherwise is_request: True if the body to read belongs to a request, False otherwise
""" """
if max_chunk_size is None:
max_chunk_size = limit or sys.maxint
expected_size = expected_http_body_size(headers, is_request, request_method, response_code)
if expected_size is None:
if has_chunked_encoding(headers): if has_chunked_encoding(headers):
content = read_chunked(rfile, headers, limit, is_request) # Python 3: yield from
elif "content-length" in headers: for x in read_chunked(rfile, limit, is_request):
try: yield x
l = int(headers["content-length"][0]) else: # pragma: nocover
if l < 0: raise HttpError(400 if is_request else 502, "Content-Length unknown but no chunked encoding")
raise ValueError() elif expected_size >= 0:
except ValueError: if limit is not None and expected_size > limit:
raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) raise HttpError(400 if is_request else 509,
if limit is not None and l > limit: "HTTP Body too large. Limit is %s, content-length was %s" % (limit, expected_size))
raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) bytes_left = expected_size
content = rfile.read(l) while bytes_left:
elif is_request: chunk_size = min(bytes_left, max_chunk_size)
content = "" yield "", rfile.read(chunk_size), ""
bytes_left -= chunk_size
else: else:
content = rfile.read(limit if limit else -1) bytes_left = limit or -1
while bytes_left:
chunk_size = min(bytes_left, max_chunk_size)
content = rfile.read(chunk_size)
if not content:
return
yield "", content, ""
bytes_left -= chunk_size
not_done = rfile.read(1) not_done = rfile.read(1)
if not_done: if not_done:
raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit)
return content
def expected_http_body_size(headers, is_request):
def expected_http_body_size(headers, is_request, request_method, response_code):
""" """
Returns length of body expected or -1 if not Returns the expected body length:
known and we should just read until end of - a positive integer, if the size is known in advance
stream. - None, if the size in unknown in advance (chunked encoding)
- -1, if all data should be read until end of stream.
""" """
# Determine response size according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3
if request_method:
request_method = request_method.upper()
if (not is_request and (
request_method == "HEAD" or
(request_method == "CONNECT" and response_code == 200) or
response_code in [204, 304] or
100 <= response_code <= 199)):
return 0
if has_chunked_encoding(headers):
return None
if "content-length" in headers: if "content-length" in headers:
try: try:
l = int(headers["content-length"][0]) size = int(headers["content-length"][0])
if l < 0: if size < 0:
raise ValueError() raise ValueError()
return l return size
except ValueError: except ValueError:
raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) raise HttpError(400 if is_request else 502, "Invalid content-length header: %s" % headers["content-length"])
elif is_request: if is_request:
return 0 return 0
return -1 return -1

View File

@ -16,79 +16,32 @@ def test_has_chunked_encoding():
def test_read_chunked(): def test_read_chunked():
h = odict.ODictCaseless()
h["transfer-encoding"] = ["chunked"]
s = cStringIO.StringIO("1\r\na\r\n0\r\n") s = cStringIO.StringIO("1\r\na\r\n0\r\n")
tutils.raises("closed prematurely", http.read_chunked, s, None, None, True)
tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True)
s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
assert http.read_chunked(s, None, None, True) == "a" assert http.read_http_body(s, h, None, "GET", None, True) == "a"
s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n")
assert http.read_chunked(s, None, None, True) == "a" assert http.read_http_body(s, h, None, "GET", None, True) == "a"
s = cStringIO.StringIO("\r\n") s = cStringIO.StringIO("\r\n")
tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) tutils.raises("closed prematurely", http.read_http_body, s, h, None, "GET", None, True)
s = cStringIO.StringIO("1\r\nfoo") s = cStringIO.StringIO("1\r\nfoo")
tutils.raises("malformed chunked body", http.read_chunked, s, None, None, True) tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True)
s = cStringIO.StringIO("foo\r\nfoo") s = cStringIO.StringIO("foo\r\nfoo")
tutils.raises(http.HttpError, http.read_chunked, s, None, None, True) tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", None, True)
s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
tutils.raises("too large", http.read_chunked, s, None, 2, True) tutils.raises("too large", http.read_http_body, s, h, 2, "GET", None, True)
def test_read_next_chunk():
s = cStringIO.StringIO(
"4\r\n" +
"mitm\r\n" +
"5\r\n" +
"proxy\r\n" +
"e\r\n" +
" in\r\n\r\nchunks.\r\n" +
"0\r\n" +
"\r\n")
assert http.read_next_chunk(s, None, False) == "mitm"
assert http.read_next_chunk(s, None, False) == "proxy"
assert http.read_next_chunk(s, None, False) == " in\r\n\r\nchunks."
assert http.read_next_chunk(s, None, False) == None
s = cStringIO.StringIO("")
tutils.raises("closed prematurely", http.read_next_chunk, s, None, False)
s = cStringIO.StringIO("1\r\na\r\n0\r\n")
http.read_next_chunk(s, None, False)
tutils.raises("closed prematurely", http.read_next_chunk, s, None, False)
s = cStringIO.StringIO("1\r\nfoo")
tutils.raises("malformed chunked body", http.read_next_chunk, s, None, False)
s = cStringIO.StringIO("foo\r\nfoo")
tutils.raises(http.HttpError, http.read_next_chunk, s, None, False)
def test_write_chunk():
expected = ("" +
"4\r\n" +
"mitm\r\n" +
"5\r\n" +
"proxy\r\n" +
"e\r\n" +
" in\r\n\r\nchunks.\r\n" +
"0\r\n" +
"\r\n")
s = cStringIO.StringIO()
http.write_chunk(s, "mitm")
http.write_chunk(s, "proxy")
http.write_chunk(s, " in\r\n\r\nchunks.")
http.write_chunk(s, None)
print len(s.getvalue())
print len(expected)
assert s.getvalue() == expected
def test_connection_close(): def test_connection_close():
h = odict.ODictCaseless() h = odict.ODictCaseless()
assert http.connection_close((1, 0), h) assert http.connection_close((1, 0), h)
@ -114,73 +67,73 @@ def test_get_header_tokens():
def test_read_http_body_request(): def test_read_http_body_request():
h = odict.ODictCaseless() h = odict.ODictCaseless()
r = cStringIO.StringIO("testing") r = cStringIO.StringIO("testing")
assert http.read_http_body(r, h, None, True) == "" assert http.read_http_body(r, h, None, "GET", None, True) == ""
def test_read_http_body_response(): def test_read_http_body_response():
h = odict.ODictCaseless() h = odict.ODictCaseless()
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert http.read_http_body(s, h, None, False) == "testing" assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"
def test_read_http_body(): def test_read_http_body():
# test default case # test default case
h = odict.ODictCaseless() h = odict.ODictCaseless()
h["content-length"] = [7] h["content-length"] = [7]
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert http.read_http_body(s, h, None, False) == "testing" assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"
# test content length: invalid header # test content length: invalid header
h["content-length"] = ["foo"] h["content-length"] = ["foo"]
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False)
# test content length: invalid header #2 # test content length: invalid header #2
h["content-length"] = [-1] h["content-length"] = [-1]
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False)
# test content length: content length > actual content # test content length: content length > actual content
h["content-length"] = [5] h["content-length"] = [5]
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False)
# test content length: content length < actual content # test content length: content length < actual content
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert len(http.read_http_body(s, h, None, False)) == 5 assert len(http.read_http_body(s, h, None, "GET", 200, False)) == 5
# test no content length: limit > actual content # test no content length: limit > actual content
h = odict.ODictCaseless() h = odict.ODictCaseless()
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert len(http.read_http_body(s, h, 100, False)) == 7 assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7
# test no content length: limit < actual content # test no content length: limit < actual content
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False)
# test chunked # test chunked
h = odict.ODictCaseless() h = odict.ODictCaseless()
h["transfer-encoding"] = ["chunked"] h["transfer-encoding"] = ["chunked"]
s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
assert http.read_http_body(s, h, 100, False) == "aaaaa" assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"
def test_expected_http_body_size(): def test_expected_http_body_size():
# gibber in the content-length field # gibber in the content-length field
h = odict.ODictCaseless() h = odict.ODictCaseless()
h["content-length"] = ["foo"] h["content-length"] = ["foo"]
tutils.raises(http.HttpError, http.expected_http_body_size, h, False) tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200)
# negative number in the content-length field # negative number in the content-length field
h = odict.ODictCaseless() h = odict.ODictCaseless()
h["content-length"] = ["-7"] h["content-length"] = ["-7"]
tutils.raises(http.HttpError, http.expected_http_body_size, h, False) tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200)
# explicit length # explicit length
h = odict.ODictCaseless() h = odict.ODictCaseless()
h["content-length"] = ["5"] h["content-length"] = ["5"]
assert http.expected_http_body_size(h, False) == 5 assert http.expected_http_body_size(h, False, "GET", 200) == 5
# no length # no length
h = odict.ODictCaseless() h = odict.ODictCaseless()
assert http.expected_http_body_size(h, False) == -1 assert http.expected_http_body_size(h, False, "GET", 200) == -1
# no length request # no length request
h = odict.ODictCaseless() h = odict.ODictCaseless()
assert http.expected_http_body_size(h, True) == 0 assert http.expected_http_body_size(h, True, "GET", None) == 0
def test_parse_http_protocol(): def test_parse_http_protocol():
assert http.parse_http_protocol("HTTP/1.1") == (1, 1) assert http.parse_http_protocol("HTTP/1.1") == (1, 1)