read_headers: handle some crashes, return None on invalid data.

This commit is contained in:
Aldo Cortesi 2012-07-30 12:50:35 +12:00
parent eafa5566c2
commit 1c21a28e64
2 changed files with 36 additions and 14 deletions

View File

@ -36,8 +36,8 @@ def parse_url(url):
def read_headers(fp):
"""
Read a set of headers from a file pointer. Stop once a blank line
is reached. Return a ODictCaseless object.
Read a set of headers from a file pointer. Stop once a blank line is
reached. Return a ODictCaseless object, or None if headers are invalid.
"""
ret = []
name = ''
@ -46,6 +46,8 @@ def read_headers(fp):
if not line or line == '\r\n' or line == '\n':
break
if line[0] in ' \t':
if not ret:
return None
# continued header
ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
else:
@ -55,6 +57,8 @@ def read_headers(fp):
name = line[:i]
value = line[i+1:].strip()
ret.append([name, value])
else:
return None
return odict.ODictCaseless(ret)
@ -282,6 +286,8 @@ def read_response(rfile, method, body_size_limit):
except ValueError:
raise HttpError(502, "Invalid server response: %s"%repr(line))
headers = read_headers(rfile)
if headers is None:
raise HttpError(502, "Invalid headers.")
if code >= 100 and code <= 199:
return read_response(rfile, method, body_size_limit)
if method == "HEAD" or code == 204 or code == 304:

View File

@ -169,16 +169,20 @@ def test_parse_init_http():
class TestReadHeaders:
def _read(self, data, verbatim=False):
if not verbatim:
data = textwrap.dedent(data)
data = data.strip()
s = cStringIO.StringIO(data)
return http.read_headers(s)
def test_read_simple(self):
data = """
Header: one
Header2: two
\r\n
"""
data = textwrap.dedent(data)
data = data.strip()
s = cStringIO.StringIO(data)
h = http.read_headers(s)
h = self._read(data)
assert h.lst == [["Header", "one"], ["Header2", "two"]]
def test_read_multi(self):
@ -187,10 +191,7 @@ class TestReadHeaders:
Header: two
\r\n
"""
data = textwrap.dedent(data)
data = data.strip()
s = cStringIO.StringIO(data)
h = http.read_headers(s)
h = self._read(data)
assert h.lst == [["Header", "one"], ["Header", "two"]]
def test_read_continued(self):
@ -200,12 +201,19 @@ class TestReadHeaders:
Header2: three
\r\n
"""
data = textwrap.dedent(data)
data = data.strip()
s = cStringIO.StringIO(data)
h = http.read_headers(s)
h = self._read(data)
assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]]
def test_read_continued_err(self):
data = "\tfoo: bar\r\n"
assert self._read(data, True) is None
def test_read_err(self):
data = """
foo
"""
assert self._read(data) is None
def test_read_response():
def tst(data, method, limit):
@ -248,6 +256,14 @@ def test_read_response():
assert tst(data, "GET", None)[4] == 'foo'
assert tst(data, "HEAD", None)[4] == ''
data = """
HTTP/1.1 200 OK
\tContent-Length: 3
foo
"""
tutils.raises("invalid headers", tst, data, "GET", None)
def test_parse_url():
assert not http.parse_url("")