From 1c21a28e6423edf3b903191610b45345720e0458 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Mon, 30 Jul 2012 12:50:35 +1200 Subject: [PATCH] read_headers: handle some crashes, return None on invalid data. --- netlib/http.py | 10 ++++++++-- test/test_http.py | 40 ++++++++++++++++++++++++++++------------ 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 980d3f625..b71eb72db 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -36,8 +36,8 @@ def parse_url(url): def read_headers(fp): """ - Read a set of headers from a file pointer. Stop once a blank line - is reached. Return a ODictCaseless object. + Read a set of headers from a file pointer. Stop once a blank line is + reached. Return a ODictCaseless object, or None if headers are invalid. """ ret = [] name = '' @@ -46,6 +46,8 @@ def read_headers(fp): if not line or line == '\r\n' or line == '\n': break if line[0] in ' \t': + if not ret: + return None # continued header ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() else: @@ -55,6 +57,8 @@ def read_headers(fp): name = line[:i] value = line[i+1:].strip() ret.append([name, value]) + else: + return None return odict.ODictCaseless(ret) @@ -282,6 +286,8 @@ def read_response(rfile, method, body_size_limit): except ValueError: raise HttpError(502, "Invalid server response: %s"%repr(line)) headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") if code >= 100 and code <= 199: return read_response(rfile, method, body_size_limit) if method == "HEAD" or code == 204 or code == 304: diff --git a/test/test_http.py b/test/test_http.py index 0b83e65a1..a6161fbcc 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -169,16 +169,20 @@ def test_parse_init_http(): class TestReadHeaders: + def _read(self, data, verbatim=False): + if not verbatim: + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + return http.read_headers(s) + def test_read_simple(self): data = """ Header: one Header2: two \r\n """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = http.read_headers(s) + h = self._read(data) assert h.lst == [["Header", "one"], ["Header2", "two"]] def test_read_multi(self): @@ -187,10 +191,7 @@ class TestReadHeaders: Header: two \r\n """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = http.read_headers(s) + h = self._read(data) assert h.lst == [["Header", "one"], ["Header", "two"]] def test_read_continued(self): @@ -200,12 +201,19 @@ class TestReadHeaders: Header2: three \r\n """ - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - h = http.read_headers(s) + h = self._read(data) assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] + def test_read_continued_err(self): + data = "\tfoo: bar\r\n" + assert self._read(data, True) is None + + def test_read_err(self): + data = """ + foo + """ + assert self._read(data) is None + def test_read_response(): def tst(data, method, limit): @@ -248,6 +256,14 @@ def test_read_response(): assert tst(data, "GET", None)[4] == 'foo' assert tst(data, "HEAD", None)[4] == '' + data = """ + HTTP/1.1 200 OK + \tContent-Length: 3 + + foo + """ + tutils.raises("invalid headers", tst, data, "GET", None) + def test_parse_url(): assert not http.parse_url("")