mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
finalize Headers, add tests
This commit is contained in:
parent
5f97701958
commit
3718e59308
@ -51,8 +51,8 @@ class Headers(UserDict.DictMixin):
|
||||
Host: example.com
|
||||
Accept: application/text
|
||||
|
||||
# For full control, the raw header lines can be accessed
|
||||
>>> h.lines
|
||||
# For full control, the raw header fields can be accessed
|
||||
>>> h.fields
|
||||
|
||||
# Headers can also be crated from keyword arguments
|
||||
>>> h = Headers(host="example.com", content_type="application/xml")
|
||||
@ -61,85 +61,112 @@ class Headers(UserDict.DictMixin):
|
||||
For use with the "Set-Cookie" header, see :py:meth:`get_all`.
|
||||
"""
|
||||
|
||||
def __init__(self, lines=None, **headers):
|
||||
def __init__(self, fields=None, **headers):
|
||||
"""
|
||||
For convenience, underscores in header names will be transformed to dashes.
|
||||
This behaviour does not extend to other methods.
|
||||
|
||||
If ``**headers`` contains multiple keys that have equal ``.lower()``s,
|
||||
the behavior is undefined.
|
||||
Args:
|
||||
fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]``
|
||||
**headers: Additional headers to set. Will overwrite existing values from `fields`.
|
||||
For convenience, underscores in header names will be transformed to dashes -
|
||||
this behaviour does not extend to other methods.
|
||||
If ``**headers`` contains multiple keys that have equal ``.lower()`` s,
|
||||
the behavior is undefined.
|
||||
"""
|
||||
self.lines = lines or []
|
||||
self.fields = fields or []
|
||||
|
||||
# content_type -> content-type
|
||||
headers = {k.replace("_", "-"): v for k, v in headers.iteritems()}
|
||||
headers = {
|
||||
name.replace("_", "-"): value
|
||||
for name, value in headers.iteritems()
|
||||
}
|
||||
self.update(headers)
|
||||
|
||||
def __str__(self):
|
||||
return "\r\n".join(": ".join(line) for line in self.lines)
|
||||
return "\r\n".join(": ".join(field) for field in self.fields)
|
||||
|
||||
def __getitem__(self, key):
|
||||
values = self.get_all(key)
|
||||
def __getitem__(self, name):
|
||||
values = self.get_all(name)
|
||||
if not values:
|
||||
raise KeyError(key)
|
||||
raise KeyError(name)
|
||||
else:
|
||||
return ", ".join(values)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
idx = self._index(key)
|
||||
def __setitem__(self, name, value):
|
||||
idx = self._index(name)
|
||||
|
||||
# To please the human eye, we insert at the same position the first existing header occured.
|
||||
if idx is not None:
|
||||
del self[key]
|
||||
self.lines.insert(idx, [key, value])
|
||||
del self[name]
|
||||
self.fields.insert(idx, [name, value])
|
||||
else:
|
||||
self.lines.append([key, value])
|
||||
self.fields.append([name, value])
|
||||
|
||||
def __delitem__(self, key):
|
||||
key = key.lower()
|
||||
self.lines = [
|
||||
line for line in self.lines
|
||||
if key != line[0].lower()
|
||||
]
|
||||
def __delitem__(self, name):
|
||||
if name not in self:
|
||||
raise KeyError(name)
|
||||
name = name.lower()
|
||||
self.fields = [
|
||||
field for field in self.fields
|
||||
if name != field[0].lower()
|
||||
]
|
||||
|
||||
def _index(self, key):
|
||||
key = key.lower()
|
||||
for i, line in enumerate(self):
|
||||
if line[0].lower() == key:
|
||||
def _index(self, name):
|
||||
name = name.lower()
|
||||
for i, field in enumerate(self.fields):
|
||||
if field[0].lower() == name:
|
||||
return i
|
||||
return None
|
||||
|
||||
def keys(self):
|
||||
return list(set(line[0] for line in self.lines))
|
||||
seen = set()
|
||||
names = []
|
||||
for name, _ in self.fields:
|
||||
name_lower = name.lower()
|
||||
if name_lower not in seen:
|
||||
seen.add(name_lower)
|
||||
names.append(name)
|
||||
return names
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.lines == other.lines
|
||||
if isinstance(other, Headers):
|
||||
return self.fields == other.fields
|
||||
return False
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def get_all(self, key, default=None):
|
||||
def get_all(self, name, default=None):
|
||||
"""
|
||||
Like :py:meth:`get`, but does not fold multiple headers into a single one.
|
||||
This is useful for Set-Cookie headers, which do not support folding.
|
||||
|
||||
See also: https://tools.ietf.org/html/rfc7230#section-3.2.2
|
||||
"""
|
||||
key = key.lower()
|
||||
values = [line[1] for line in self.lines if line[0].lower() == key]
|
||||
name = name.lower()
|
||||
values = [value for n, value in self.fields if n.lower() == name]
|
||||
return values or default
|
||||
|
||||
def set_all(self, key, values):
|
||||
def set_all(self, name, values):
|
||||
"""
|
||||
Explicitly set multiple headers for the given key.
|
||||
See: :py:meth:`get_all`
|
||||
"""
|
||||
if key in self:
|
||||
del self[key]
|
||||
self.lines.extend(
|
||||
[key, value] for value in values
|
||||
if name in self:
|
||||
del self[name]
|
||||
self.fields.extend(
|
||||
[name, value] for value in values
|
||||
)
|
||||
|
||||
# Implement the StateObject protocol from mitmproxy
|
||||
def get_state(self, short=False):
|
||||
return tuple(tuple(field) for field in self.fields)
|
||||
|
||||
def load_state(self, state):
|
||||
self.fields = [list(field) for field in state]
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
return cls([list(field) for field in state])
|
||||
|
||||
|
||||
class ProtocolMixin(object):
|
||||
def read_request(self, *args, **kwargs): # pragma: no cover
|
||||
|
@ -445,3 +445,148 @@ class TestResponse(object):
|
||||
v = resp.get_cookies()
|
||||
assert len(v) == 1
|
||||
assert v["foo"] == [["bar", odict.ODictCaseless()]]
|
||||
|
||||
|
||||
class TestHeaders(object):
|
||||
def _2host(self):
|
||||
return semantics.Headers(
|
||||
[
|
||||
["Host", "example.com"],
|
||||
["host", "example.org"]
|
||||
]
|
||||
)
|
||||
|
||||
def test_init(self):
|
||||
h = semantics.Headers()
|
||||
assert len(h) == 0
|
||||
|
||||
h = semantics.Headers([["Host", "example.com"]])
|
||||
assert len(h) == 1
|
||||
assert h["Host"] == "example.com"
|
||||
|
||||
h = semantics.Headers(Host="example.com")
|
||||
assert len(h) == 1
|
||||
assert h["Host"] == "example.com"
|
||||
|
||||
h = semantics.Headers(
|
||||
[["Host", "invalid"]],
|
||||
Host="example.com"
|
||||
)
|
||||
assert len(h) == 1
|
||||
assert h["Host"] == "example.com"
|
||||
|
||||
h = semantics.Headers(
|
||||
[["Host", "invalid"], ["Accept", "text/plain"]],
|
||||
Host="example.com"
|
||||
)
|
||||
assert len(h) == 2
|
||||
assert h["Host"] == "example.com"
|
||||
assert h["Accept"] == "text/plain"
|
||||
|
||||
def test_getitem(self):
|
||||
h = semantics.Headers(Host="example.com")
|
||||
assert h["Host"] == "example.com"
|
||||
assert h["host"] == "example.com"
|
||||
tutils.raises(KeyError, h.__getitem__, "Accept")
|
||||
|
||||
h = self._2host()
|
||||
assert h["Host"] == "example.com, example.org"
|
||||
|
||||
def test_str(self):
|
||||
h = semantics.Headers(Host="example.com")
|
||||
assert str(h) == "Host: example.com"
|
||||
|
||||
h = semantics.Headers([
|
||||
["Host", "example.com"],
|
||||
["Accept", "text/plain"]
|
||||
])
|
||||
assert str(h) == "Host: example.com\r\nAccept: text/plain"
|
||||
|
||||
def test_setitem(self):
|
||||
h = semantics.Headers()
|
||||
h["Host"] = "example.com"
|
||||
assert "Host" in h
|
||||
assert "host" in h
|
||||
assert h["Host"] == "example.com"
|
||||
|
||||
h["host"] = "example.org"
|
||||
assert "Host" in h
|
||||
assert "host" in h
|
||||
assert h["Host"] == "example.org"
|
||||
|
||||
h["accept"] = "text/plain"
|
||||
assert len(h) == 2
|
||||
assert "Accept" in h
|
||||
assert "Host" in h
|
||||
|
||||
h = self._2host()
|
||||
assert len(h.fields) == 2
|
||||
h["Host"] = "example.com"
|
||||
assert len(h.fields) == 1
|
||||
assert "Host" in h
|
||||
|
||||
def test_delitem(self):
|
||||
h = semantics.Headers(Host="example.com")
|
||||
assert len(h) == 1
|
||||
del h["host"]
|
||||
assert len(h) == 0
|
||||
try:
|
||||
del h["host"]
|
||||
except KeyError:
|
||||
assert True
|
||||
else:
|
||||
assert False
|
||||
|
||||
h = self._2host()
|
||||
del h["Host"]
|
||||
assert len(h) == 0
|
||||
|
||||
def test_keys(self):
|
||||
h = semantics.Headers(Host="example.com")
|
||||
assert len(h.keys()) == 1
|
||||
assert h.keys()[0] == "Host"
|
||||
|
||||
h = self._2host()
|
||||
assert len(h.keys()) == 1
|
||||
assert h.keys()[0] == "Host"
|
||||
|
||||
def test_eq_ne(self):
|
||||
h1 = semantics.Headers(Host="example.com")
|
||||
h2 = semantics.Headers(host="example.com")
|
||||
assert not (h1 == h2)
|
||||
assert h1 != h2
|
||||
|
||||
h1 = semantics.Headers(Host="example.com")
|
||||
h2 = semantics.Headers(Host="example.com")
|
||||
assert h1 == h2
|
||||
assert not (h1 != h2)
|
||||
|
||||
assert h1 != None
|
||||
|
||||
def test_get_all(self):
|
||||
h = self._2host()
|
||||
assert h.get_all("host") == ["example.com", "example.org"]
|
||||
assert h.get_all("accept", 42) is 42
|
||||
|
||||
def test_set_all(self):
|
||||
h = semantics.Headers(Host="example.com")
|
||||
h.set_all("Accept", ["text/plain"])
|
||||
assert len(h) == 2
|
||||
assert "accept" in h
|
||||
|
||||
h = self._2host()
|
||||
h.set_all("Host", ["example.org"])
|
||||
assert h["host"] == "example.org"
|
||||
|
||||
h.set_all("Host", ["example.org", "example.net"])
|
||||
assert h["host"] == "example.org, example.net"
|
||||
|
||||
def test_state(self):
|
||||
h = self._2host()
|
||||
assert len(h.get_state()) == 2
|
||||
assert h == semantics.Headers.from_state(h.get_state())
|
||||
|
||||
h2 = semantics.Headers()
|
||||
assert h != h2
|
||||
h2.load_state(h.get_state())
|
||||
assert h == h2
|
||||
|
Loading…
Reference in New Issue
Block a user