mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
ODict improvements
- Setting values now tries to preserve the existing order, rather than just appending to the end. - __repr__ now returns a repr of the tuple list. The old repr becomes a .format() method. This is clearer, makes troubleshooting easier, and doesn't assume all data in ODicts are header-like
This commit is contained in:
parent
aeebf31927
commit
0c85c72dc4
@ -13,7 +13,8 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
|
|||||||
|
|
||||||
class ODict(object):
|
class ODict(object):
|
||||||
"""
|
"""
|
||||||
A dictionary-like object for managing ordered (key, value) data.
|
A dictionary-like object for managing ordered (key, value) data. Think
|
||||||
|
about it as a convenient interface to a list of (key, value) tuples.
|
||||||
"""
|
"""
|
||||||
def __init__(self, lst=None):
|
def __init__(self, lst=None):
|
||||||
self.lst = lst or []
|
self.lst = lst or []
|
||||||
@ -64,11 +65,20 @@ class ODict(object):
|
|||||||
key, they are cleared.
|
key, they are cleared.
|
||||||
"""
|
"""
|
||||||
if isinstance(valuelist, basestring):
|
if isinstance(valuelist, basestring):
|
||||||
raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']")
|
raise ValueError(
|
||||||
|
"Expected list of values instead of string. "
|
||||||
new = self._filter_lst(k, self.lst)
|
"Example: odict['Host'] = ['www.example.com']"
|
||||||
for i in valuelist:
|
)
|
||||||
new.append([k, i])
|
kc = self._kconv(k)
|
||||||
|
new = []
|
||||||
|
for i in self.lst:
|
||||||
|
if self._kconv(i[0]) == kc:
|
||||||
|
if valuelist:
|
||||||
|
new.append([k, valuelist.pop(0)])
|
||||||
|
else:
|
||||||
|
new.append(i)
|
||||||
|
while valuelist:
|
||||||
|
new.append([k, valuelist.pop(0)])
|
||||||
self.lst = new
|
self.lst = new
|
||||||
|
|
||||||
def __delitem__(self, k):
|
def __delitem__(self, k):
|
||||||
@ -115,6 +125,9 @@ class ODict(object):
|
|||||||
self.lst.extend(other.lst)
|
self.lst.extend(other.lst)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
return repr(self.lst)
|
||||||
|
|
||||||
|
def format(self):
|
||||||
elements = []
|
elements = []
|
||||||
for itm in self.lst:
|
for itm in self.lst:
|
||||||
elements.append(itm[0] + ": " + str(itm[1]))
|
elements.append(itm[0] + ": " + str(itm[1]))
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
from __future__ import (absolute_import, print_function, division)
|
from __future__ import (absolute_import, print_function, division)
|
||||||
import cStringIO, urllib, time, traceback
|
import cStringIO
|
||||||
|
import urllib
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
from . import odict, tcp
|
from . import odict, tcp
|
||||||
|
|
||||||
|
|
||||||
@ -23,15 +26,18 @@ class Request(object):
|
|||||||
def date_time_string():
|
def date_time_string():
|
||||||
"""Return the current date and time formatted for a message header."""
|
"""Return the current date and time formatted for a message header."""
|
||||||
WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
|
WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
|
||||||
MONTHS = [None,
|
MONTHS = [
|
||||||
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
|
None,
|
||||||
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
|
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
|
||||||
|
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'
|
||||||
|
]
|
||||||
now = time.time()
|
now = time.time()
|
||||||
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now)
|
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now)
|
||||||
s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
|
s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
|
||||||
WEEKS[wd],
|
WEEKS[wd],
|
||||||
day, MONTHS[month], year,
|
day, MONTHS[month], year,
|
||||||
hh, mm, ss)
|
hh, mm, ss
|
||||||
|
)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
@ -100,6 +106,7 @@ class WSGIAdaptor(object):
|
|||||||
status = None,
|
status = None,
|
||||||
headers = None
|
headers = None
|
||||||
)
|
)
|
||||||
|
|
||||||
def write(data):
|
def write(data):
|
||||||
if not state["headers_sent"]:
|
if not state["headers_sent"]:
|
||||||
soc.write("HTTP/1.1 %s\r\n"%state["status"])
|
soc.write("HTTP/1.1 %s\r\n"%state["status"])
|
||||||
@ -108,7 +115,7 @@ class WSGIAdaptor(object):
|
|||||||
h["Server"] = [self.sversion]
|
h["Server"] = [self.sversion]
|
||||||
if 'date' not in h:
|
if 'date' not in h:
|
||||||
h["Date"] = [date_time_string()]
|
h["Date"] = [date_time_string()]
|
||||||
soc.write(str(h))
|
soc.write(h.format())
|
||||||
soc.write("\r\n")
|
soc.write("\r\n")
|
||||||
state["headers_sent"] = True
|
state["headers_sent"] = True
|
||||||
if data:
|
if data:
|
||||||
@ -130,7 +137,9 @@ class WSGIAdaptor(object):
|
|||||||
|
|
||||||
errs = cStringIO.StringIO()
|
errs = cStringIO.StringIO()
|
||||||
try:
|
try:
|
||||||
dataiter = self.app(self.make_environ(request, errs, **env), start_response)
|
dataiter = self.app(
|
||||||
|
self.make_environ(request, errs, **env), start_response
|
||||||
|
)
|
||||||
for i in dataiter:
|
for i in dataiter:
|
||||||
write(i)
|
write(i)
|
||||||
if not state["headers_sent"]:
|
if not state["headers_sent"]:
|
||||||
@ -143,5 +152,3 @@ class WSGIAdaptor(object):
|
|||||||
except Exception: # pragma: no cover
|
except Exception: # pragma: no cover
|
||||||
pass
|
pass
|
||||||
return errs.getvalue()
|
return errs.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,6 +53,7 @@ def test_connection_close():
|
|||||||
h["connection"] = ["close"]
|
h["connection"] = ["close"]
|
||||||
assert http.connection_close((1, 1), h)
|
assert http.connection_close((1, 1), h)
|
||||||
|
|
||||||
|
|
||||||
def test_get_header_tokens():
|
def test_get_header_tokens():
|
||||||
h = odict.ODictCaseless()
|
h = odict.ODictCaseless()
|
||||||
assert http.get_header_tokens(h, "foo") == []
|
assert http.get_header_tokens(h, "foo") == []
|
||||||
@ -69,11 +70,13 @@ def test_read_http_body_request():
|
|||||||
r = cStringIO.StringIO("testing")
|
r = cStringIO.StringIO("testing")
|
||||||
assert http.read_http_body(r, h, None, "GET", None, True) == ""
|
assert http.read_http_body(r, h, None, "GET", None, True) == ""
|
||||||
|
|
||||||
|
|
||||||
def test_read_http_body_response():
|
def test_read_http_body_response():
|
||||||
h = odict.ODictCaseless()
|
h = odict.ODictCaseless()
|
||||||
s = cStringIO.StringIO("testing")
|
s = cStringIO.StringIO("testing")
|
||||||
assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"
|
assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"
|
||||||
|
|
||||||
|
|
||||||
def test_read_http_body():
|
def test_read_http_body():
|
||||||
# test default case
|
# test default case
|
||||||
h = odict.ODictCaseless()
|
h = odict.ODictCaseless()
|
||||||
@ -115,6 +118,7 @@ def test_read_http_body():
|
|||||||
s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
|
s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
|
||||||
assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"
|
assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"
|
||||||
|
|
||||||
|
|
||||||
def test_expected_http_body_size():
|
def test_expected_http_body_size():
|
||||||
# gibber in the content-length field
|
# gibber in the content-length field
|
||||||
h = odict.ODictCaseless()
|
h = odict.ODictCaseless()
|
||||||
@ -135,6 +139,7 @@ def test_expected_http_body_size():
|
|||||||
h = odict.ODictCaseless()
|
h = odict.ODictCaseless()
|
||||||
assert http.expected_http_body_size(h, True, "GET", None) == 0
|
assert http.expected_http_body_size(h, True, "GET", None) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_parse_http_protocol():
|
def test_parse_http_protocol():
|
||||||
assert http.parse_http_protocol("HTTP/1.1") == (1, 1)
|
assert http.parse_http_protocol("HTTP/1.1") == (1, 1)
|
||||||
assert http.parse_http_protocol("HTTP/0.0") == (0, 0)
|
assert http.parse_http_protocol("HTTP/0.0") == (0, 0)
|
||||||
@ -189,6 +194,7 @@ def test_parse_init_http():
|
|||||||
assert not http.parse_init_http("GET /test foo/1.1")
|
assert not http.parse_init_http("GET /test foo/1.1")
|
||||||
assert not http.parse_init_http("GET /test\xc0 HTTP/1.1")
|
assert not http.parse_init_http("GET /test\xc0 HTTP/1.1")
|
||||||
|
|
||||||
|
|
||||||
class TestReadHeaders:
|
class TestReadHeaders:
|
||||||
def _read(self, data, verbatim=False):
|
def _read(self, data, verbatim=False):
|
||||||
if not verbatim:
|
if not verbatim:
|
||||||
@ -251,11 +257,12 @@ class TestReadResponseNoContentLength(test.ServerTestBase):
|
|||||||
httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None)
|
httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None)
|
||||||
assert content == "bar\r\n\r\n"
|
assert content == "bar\r\n\r\n"
|
||||||
|
|
||||||
|
|
||||||
def test_read_response():
|
def test_read_response():
|
||||||
def tst(data, method, limit, include_body=True):
|
def tst(data, method, limit, include_body=True):
|
||||||
data = textwrap.dedent(data)
|
data = textwrap.dedent(data)
|
||||||
r = cStringIO.StringIO(data)
|
r = cStringIO.StringIO(data)
|
||||||
return http.read_response(r, method, limit, include_body=include_body)
|
return http.read_response(r, method, limit, include_body = include_body)
|
||||||
|
|
||||||
tutils.raises("server disconnect", tst, "", "GET", None)
|
tutils.raises("server disconnect", tst, "", "GET", None)
|
||||||
tutils.raises("invalid server response", tst, "foo", "GET", None)
|
tutils.raises("invalid server response", tst, "foo", "GET", None)
|
||||||
@ -351,6 +358,7 @@ def test_parse_url():
|
|||||||
# Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt
|
# Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt
|
||||||
assert not http.parse_url('http://lo[calhost')
|
assert not http.parse_url('http://lo[calhost')
|
||||||
|
|
||||||
|
|
||||||
def test_parse_http_basic_auth():
|
def test_parse_http_basic_auth():
|
||||||
vals = ("basic", "foo", "bar")
|
vals = ("basic", "foo", "bar")
|
||||||
assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals
|
assert http.parse_http_basic_auth(http.assemble_http_basic_auth(*vals)) == vals
|
||||||
@ -358,4 +366,3 @@ def test_parse_http_basic_auth():
|
|||||||
assert not http.parse_http_basic_auth("foo bar")
|
assert not http.parse_http_basic_auth("foo bar")
|
||||||
v = "basic " + binascii.b2a_base64("foo")
|
v = "basic " + binascii.b2a_base64("foo")
|
||||||
assert not http.parse_http_basic_auth(v)
|
assert not http.parse_http_basic_auth(v)
|
||||||
|
|
||||||
|
@ -135,18 +135,6 @@ def test_cookie_roundtrips():
|
|||||||
nose.tools.eq_(ret.lst, lst)
|
nose.tools.eq_(ret.lst, lst)
|
||||||
|
|
||||||
|
|
||||||
# TODO
|
|
||||||
# I've seen the following pathological cookie in the wild:
|
|
||||||
#
|
|
||||||
# cid=09,0,0,0,0; expires=Wed, 10-Jun-2015 21:54:53 GMT; path=/
|
|
||||||
#
|
|
||||||
# It's not compliant under any RFC - the latest RFC prohibits commas in cookie
|
|
||||||
# values completely, earlier RFCs require them to be within a quoted string.
|
|
||||||
#
|
|
||||||
# If we ditch support for earlier RFCs, we can handle this correctly. This
|
|
||||||
# leaves us with the question: what's more common, multiple-value Set-Cookie
|
|
||||||
# headers, or Set-Cookie headers that violate the standards?
|
|
||||||
|
|
||||||
def test_parse_set_cookie_pairs():
|
def test_parse_set_cookie_pairs():
|
||||||
pairs = [
|
pairs = [
|
||||||
[
|
[
|
||||||
@ -205,6 +193,9 @@ def test_parse_set_cookie_header():
|
|||||||
[
|
[
|
||||||
"", None
|
"", None
|
||||||
],
|
],
|
||||||
|
[
|
||||||
|
";", None
|
||||||
|
],
|
||||||
[
|
[
|
||||||
"one=uno",
|
"one=uno",
|
||||||
("one", "uno", [])
|
("one", "uno", [])
|
||||||
|
@ -6,6 +6,11 @@ class TestODict:
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.od = odict.ODict()
|
self.od = odict.ODict()
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
h = odict.ODict()
|
||||||
|
h["one"] = ["two"]
|
||||||
|
assert repr(h)
|
||||||
|
|
||||||
def test_str_err(self):
|
def test_str_err(self):
|
||||||
h = odict.ODict()
|
h = odict.ODict()
|
||||||
tutils.raises(ValueError, h.__setitem__, "key", "foo")
|
tutils.raises(ValueError, h.__setitem__, "key", "foo")
|
||||||
@ -20,7 +25,7 @@ class TestODict:
|
|||||||
"two: tre\r\n",
|
"two: tre\r\n",
|
||||||
"\r\n"
|
"\r\n"
|
||||||
]
|
]
|
||||||
out = repr(self.od)
|
out = self.od.format()
|
||||||
for i in expected:
|
for i in expected:
|
||||||
assert out.find(i) >= 0
|
assert out.find(i) >= 0
|
||||||
|
|
||||||
@ -39,7 +44,7 @@ class TestODict:
|
|||||||
self.od["one"] = ["uno"]
|
self.od["one"] = ["uno"]
|
||||||
expected1 = "one: uno\r\n"
|
expected1 = "one: uno\r\n"
|
||||||
expected2 = "\r\n"
|
expected2 = "\r\n"
|
||||||
out = repr(self.od)
|
out = self.od.format()
|
||||||
assert out.find(expected1) >= 0
|
assert out.find(expected1) >= 0
|
||||||
assert out.find(expected2) >= 0
|
assert out.find(expected2) >= 0
|
||||||
|
|
||||||
@ -150,3 +155,19 @@ class TestODictCaseless:
|
|||||||
assert self.od.keys() == ["foo"]
|
assert self.od.keys() == ["foo"]
|
||||||
self.od.add("bar", 2)
|
self.od.add("bar", 2)
|
||||||
assert len(self.od.keys()) == 2
|
assert len(self.od.keys()) == 2
|
||||||
|
|
||||||
|
def test_add_order(self):
|
||||||
|
od = odict.ODict(
|
||||||
|
[
|
||||||
|
["one", "uno"],
|
||||||
|
["two", "due"],
|
||||||
|
["three", "tre"],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
od["two"] = ["foo", "bar"]
|
||||||
|
assert od.lst == [
|
||||||
|
["one", "uno"],
|
||||||
|
["two", "foo"],
|
||||||
|
["three", "tre"],
|
||||||
|
["two", "bar"],
|
||||||
|
]
|
||||||
|
@ -100,4 +100,3 @@ class TestWSGI:
|
|||||||
start_response(status, response_headers, ei)
|
start_response(status, response_headers, ei)
|
||||||
yield "bbb"
|
yield "bbb"
|
||||||
assert "Internal Server Error" in self._serve(app)
|
assert "Internal Server Error" in self._serve(app)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user