read_headers: handle some crashes, return None on invalid data.

This commit is contained in:
Aldo Cortesi 2012-07-30 12:50:35 +12:00
parent eafa5566c2
commit 1c21a28e64
2 changed files with 36 additions and 14 deletions

View File

@ -36,8 +36,8 @@ def parse_url(url):
def read_headers(fp): def read_headers(fp):
""" """
Read a set of headers from a file pointer. Stop once a blank line Read a set of headers from a file pointer. Stop once a blank line is
is reached. Return a ODictCaseless object. reached. Return a ODictCaseless object, or None if headers are invalid.
""" """
ret = [] ret = []
name = '' name = ''
@ -46,6 +46,8 @@ def read_headers(fp):
if not line or line == '\r\n' or line == '\n': if not line or line == '\r\n' or line == '\n':
break break
if line[0] in ' \t': if line[0] in ' \t':
if not ret:
return None
# continued header # continued header
ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
else: else:
@ -55,6 +57,8 @@ def read_headers(fp):
name = line[:i] name = line[:i]
value = line[i+1:].strip() value = line[i+1:].strip()
ret.append([name, value]) ret.append([name, value])
else:
return None
return odict.ODictCaseless(ret) return odict.ODictCaseless(ret)
@ -282,6 +286,8 @@ def read_response(rfile, method, body_size_limit):
except ValueError: except ValueError:
raise HttpError(502, "Invalid server response: %s"%repr(line)) raise HttpError(502, "Invalid server response: %s"%repr(line))
headers = read_headers(rfile) headers = read_headers(rfile)
if headers is None:
raise HttpError(502, "Invalid headers.")
if code >= 100 and code <= 199: if code >= 100 and code <= 199:
return read_response(rfile, method, body_size_limit) return read_response(rfile, method, body_size_limit)
if method == "HEAD" or code == 204 or code == 304: if method == "HEAD" or code == 204 or code == 304:

View File

@ -169,16 +169,20 @@ def test_parse_init_http():
class TestReadHeaders: class TestReadHeaders:
def _read(self, data, verbatim=False):
if not verbatim:
data = textwrap.dedent(data)
data = data.strip()
s = cStringIO.StringIO(data)
return http.read_headers(s)
def test_read_simple(self): def test_read_simple(self):
data = """ data = """
Header: one Header: one
Header2: two Header2: two
\r\n \r\n
""" """
data = textwrap.dedent(data) h = self._read(data)
data = data.strip()
s = cStringIO.StringIO(data)
h = http.read_headers(s)
assert h.lst == [["Header", "one"], ["Header2", "two"]] assert h.lst == [["Header", "one"], ["Header2", "two"]]
def test_read_multi(self): def test_read_multi(self):
@ -187,10 +191,7 @@ class TestReadHeaders:
Header: two Header: two
\r\n \r\n
""" """
data = textwrap.dedent(data) h = self._read(data)
data = data.strip()
s = cStringIO.StringIO(data)
h = http.read_headers(s)
assert h.lst == [["Header", "one"], ["Header", "two"]] assert h.lst == [["Header", "one"], ["Header", "two"]]
def test_read_continued(self): def test_read_continued(self):
@ -200,12 +201,19 @@ class TestReadHeaders:
Header2: three Header2: three
\r\n \r\n
""" """
data = textwrap.dedent(data) h = self._read(data)
data = data.strip()
s = cStringIO.StringIO(data)
h = http.read_headers(s)
assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]]
def test_read_continued_err(self):
data = "\tfoo: bar\r\n"
assert self._read(data, True) is None
def test_read_err(self):
data = """
foo
"""
assert self._read(data) is None
def test_read_response(): def test_read_response():
def tst(data, method, limit): def tst(data, method, limit):
@ -248,6 +256,14 @@ def test_read_response():
assert tst(data, "GET", None)[4] == 'foo' assert tst(data, "GET", None)[4] == 'foo'
assert tst(data, "HEAD", None)[4] == '' assert tst(data, "HEAD", None)[4] == ''
data = """
HTTP/1.1 200 OK
\tContent-Length: 3
foo
"""
tutils.raises("invalid headers", tst, data, "GET", None)
def test_parse_url(): def test_parse_url():
assert not http.parse_url("") assert not http.parse_url("")