ODict improvements

- Setting values now tries to preserve the existing order, rather than
just appending to the end.
- __repr__ now returns  a repr of the tuple list. The old repr becomes a
.format() method. This is clearer, makes troubleshooting easier, and
doesn't assume all data in ODicts are header-like
This commit is contained in:
Aldo Cortesi 2015-04-15 10:28:17 +12:00
parent aeebf31927
commit 0c85c72dc4
6 changed files with 72 additions and 34 deletions

View File

@ -13,7 +13,8 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
class ODict(object): class ODict(object):
""" """
A dictionary-like object for managing ordered (key, value) data. A dictionary-like object for managing ordered (key, value) data. Think
about it as a convenient interface to a list of (key, value) tuples.
""" """
def __init__(self, lst=None): def __init__(self, lst=None):
self.lst = lst or [] self.lst = lst or []
@ -64,11 +65,20 @@ class ODict(object):
key, they are cleared. key, they are cleared.
""" """
if isinstance(valuelist, basestring): if isinstance(valuelist, basestring):
raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") raise ValueError(
"Expected list of values instead of string. "
new = self._filter_lst(k, self.lst) "Example: odict['Host'] = ['www.example.com']"
for i in valuelist: )
new.append([k, i]) kc = self._kconv(k)
new = []
for i in self.lst:
if self._kconv(i[0]) == kc:
if valuelist:
new.append([k, valuelist.pop(0)])
else:
new.append(i)
while valuelist:
new.append([k, valuelist.pop(0)])
self.lst = new self.lst = new
def __delitem__(self, k): def __delitem__(self, k):
@ -115,6 +125,9 @@ class ODict(object):
self.lst.extend(other.lst) self.lst.extend(other.lst)
def __repr__(self): def __repr__(self):
return repr(self.lst)
def format(self):
elements = [] elements = []
for itm in self.lst: for itm in self.lst:
elements.append(itm[0] + ": " + str(itm[1])) elements.append(itm[0] + ": " + str(itm[1]))

View File

@ -1,5 +1,8 @@
from __future__ import (absolute_import, print_function, division) from __future__ import (absolute_import, print_function, division)
import cStringIO, urllib, time, traceback import cStringIO
import urllib
import time
import traceback
from . import odict, tcp from . import odict, tcp
@ -23,15 +26,18 @@ class Request(object):
def date_time_string(): def date_time_string():
"""Return the current date and time formatted for a message header.""" """Return the current date and time formatted for a message header."""
WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
MONTHS = [None, MONTHS = [
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', None,
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'
]
now = time.time() now = time.time()
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now)
s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
WEEKS[wd], WEEKS[wd],
day, MONTHS[month], year, day, MONTHS[month], year,
hh, mm, ss) hh, mm, ss
)
return s return s
@ -100,6 +106,7 @@ class WSGIAdaptor(object):
status = None, status = None,
headers = None headers = None
) )
def write(data): def write(data):
if not state["headers_sent"]: if not state["headers_sent"]:
soc.write("HTTP/1.1 %s\r\n"%state["status"]) soc.write("HTTP/1.1 %s\r\n"%state["status"])
@ -108,7 +115,7 @@ class WSGIAdaptor(object):
h["Server"] = [self.sversion] h["Server"] = [self.sversion]
if 'date' not in h: if 'date' not in h:
h["Date"] = [date_time_string()] h["Date"] = [date_time_string()]
soc.write(str(h)) soc.write(h.format())
soc.write("\r\n") soc.write("\r\n")
state["headers_sent"] = True state["headers_sent"] = True
if data: if data:
@ -130,7 +137,9 @@ class WSGIAdaptor(object):
errs = cStringIO.StringIO() errs = cStringIO.StringIO()
try: try:
dataiter = self.app(self.make_environ(request, errs, **env), start_response) dataiter = self.app(
self.make_environ(request, errs, **env), start_response
)
for i in dataiter: for i in dataiter:
write(i) write(i)
if not state["headers_sent"]: if not state["headers_sent"]:
@ -143,5 +152,3 @@ class WSGIAdaptor(object):
except Exception: # pragma: no cover except Exception: # pragma: no cover
pass pass
return errs.getvalue() return errs.getvalue()

View File

@ -53,6 +53,7 @@ def test_connection_close():
h["connection"] = ["close"] h["connection"] = ["close"]
assert http.connection_close((1, 1), h) assert http.connection_close((1, 1), h)
def test_get_header_tokens(): def test_get_header_tokens():
h = odict.ODictCaseless() h = odict.ODictCaseless()
assert http.get_header_tokens(h, "foo") == [] assert http.get_header_tokens(h, "foo") == []
@ -69,11 +70,13 @@ def test_read_http_body_request():
r = cStringIO.StringIO("testing") r = cStringIO.StringIO("testing")
assert http.read_http_body(r, h, None, "GET", None, True) == "" assert http.read_http_body(r, h, None, "GET", None, True) == ""
def test_read_http_body_response(): def test_read_http_body_response():
h = odict.ODictCaseless() h = odict.ODictCaseless()
s = cStringIO.StringIO("testing") s = cStringIO.StringIO("testing")
assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"
def test_read_http_body(): def test_read_http_body():
# test default case # test default case
h = odict.ODictCaseless() h = odict.ODictCaseless()
@ -115,6 +118,7 @@ def test_read_http_body():
s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"
def test_expected_http_body_size(): def test_expected_http_body_size():
# gibber in the content-length field # gibber in the content-length field
h = odict.ODictCaseless() h = odict.ODictCaseless()
@ -135,6 +139,7 @@ def test_expected_http_body_size():
h = odict.ODictCaseless() h = odict.ODictCaseless()
assert http.expected_http_body_size(h, True, "GET", None) == 0 assert http.expected_http_body_size(h, True, "GET", None) == 0
def test_parse_http_protocol(): def test_parse_http_protocol():
assert http.parse_http_protocol("HTTP/1.1") == (1, 1) assert http.parse_http_protocol("HTTP/1.1") == (1, 1)
assert http.parse_http_protocol("HTTP/0.0") == (0, 0) assert http.parse_http_protocol("HTTP/0.0") == (0, 0)
@ -189,6 +194,7 @@ def test_parse_init_http():
assert not http.parse_init_http("GET /test foo/1.1") assert not http.parse_init_http("GET /test foo/1.1")
assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") assert not http.parse_init_http("GET /test\xc0 HTTP/1.1")
class TestReadHeaders: class TestReadHeaders:
def _read(self, data, verbatim=False): def _read(self, data, verbatim=False):
if not verbatim: if not verbatim:
@ -251,11 +257,12 @@ class TestReadResponseNoContentLength(test.ServerTestBase):
httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None)
assert content == "bar\r\n\r\n" assert content == "bar\r\n\r\n"
def test_read_response(): def test_read_response():
def tst(data, method, limit, include_body=True): def tst(data, method, limit, include_body=True):
data = textwrap.dedent(data) data = textwrap.dedent(data)
r = cStringIO.StringIO(data) r = cStringIO.StringIO(data)
return http.read_response(r, method, limit, include_body=include_body) return http.read_response(r, method, limit, include_body = include_body)
tutils.raises("server disconnect", tst, "", "GET", None) tutils.raises("server disconnect", tst, "", "GET", None)
tutils.raises("invalid server response", tst, "foo", "GET", None) tutils.raises("invalid server response", tst, "foo", "GET", None)
@ -351,6 +358,7 @@ def test_parse_url():
# Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt
assert not http.parse_url('http://lo[calhost') assert not http.parse_url('http://lo[calhost')
def test_parse_http_basic_auth(): def test_parse_http_basic_auth():
vals = ("basic", "foo", "bar") vals = ("basic", "foo", "bar")
assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals
@ -358,4 +366,3 @@ def test_parse_http_basic_auth():
assert not http.parse_http_basic_auth("foo bar") assert not http.parse_http_basic_auth("foo bar")
v = "basic " + binascii.b2a_base64("foo") v = "basic " + binascii.b2a_base64("foo")
assert not http.parse_http_basic_auth(v) assert not http.parse_http_basic_auth(v)

View File

@ -135,18 +135,6 @@ def test_cookie_roundtrips():
nose.tools.eq_(ret.lst, lst) nose.tools.eq_(ret.lst, lst)
# TODO
# I've seen the following pathological cookie in the wild:
#
# cid=09,0,0,0,0; expires=Wed, 10-Jun-2015 21:54:53 GMT; path=/
#
# It's not compliant under any RFC - the latest RFC prohibits commas in cookie
# values completely, earlier RFCs require them to be within a quoted string.
#
# If we ditch support for earlier RFCs, we can handle this correctly. This
# leaves us with the question: what's more common, multiple-value Set-Cookie
# headers, or Set-Cookie headers that violate the standards?
def test_parse_set_cookie_pairs(): def test_parse_set_cookie_pairs():
pairs = [ pairs = [
[ [
@ -205,6 +193,9 @@ def test_parse_set_cookie_header():
[ [
"", None "", None
], ],
[
";", None
],
[ [
"one=uno", "one=uno",
("one", "uno", []) ("one", "uno", [])

View File

@ -6,6 +6,11 @@ class TestODict:
def setUp(self): def setUp(self):
self.od = odict.ODict() self.od = odict.ODict()
def test_repr(self):
h = odict.ODict()
h["one"] = ["two"]
assert repr(h)
def test_str_err(self): def test_str_err(self):
h = odict.ODict() h = odict.ODict()
tutils.raises(ValueError, h.__setitem__, "key", "foo") tutils.raises(ValueError, h.__setitem__, "key", "foo")
@ -20,7 +25,7 @@ class TestODict:
"two: tre\r\n", "two: tre\r\n",
"\r\n" "\r\n"
] ]
out = repr(self.od) out = self.od.format()
for i in expected: for i in expected:
assert out.find(i) >= 0 assert out.find(i) >= 0
@ -39,7 +44,7 @@ class TestODict:
self.od["one"] = ["uno"] self.od["one"] = ["uno"]
expected1 = "one: uno\r\n" expected1 = "one: uno\r\n"
expected2 = "\r\n" expected2 = "\r\n"
out = repr(self.od) out = self.od.format()
assert out.find(expected1) >= 0 assert out.find(expected1) >= 0
assert out.find(expected2) >= 0 assert out.find(expected2) >= 0
@ -150,3 +155,19 @@ class TestODictCaseless:
assert self.od.keys() == ["foo"] assert self.od.keys() == ["foo"]
self.od.add("bar", 2) self.od.add("bar", 2)
assert len(self.od.keys()) == 2 assert len(self.od.keys()) == 2
def test_add_order(self):
od = odict.ODict(
[
["one", "uno"],
["two", "due"],
["three", "tre"],
]
)
od["two"] = ["foo", "bar"]
assert od.lst == [
["one", "uno"],
["two", "foo"],
["three", "tre"],
["two", "bar"],
]

View File

@ -100,4 +100,3 @@ class TestWSGI:
start_response(status, response_headers, ei) start_response(status, response_headers, ei)
yield "bbb" yield "bbb"
assert "Internal Server Error" in self._serve(app) assert "Internal Server Error" in self._serve(app)