clean up odict

This commit is contained in:
Maximilian Hils 2016-05-18 19:01:49 -07:00
parent 4c3fb8f509
commit d1fc694952
2 changed files with 34 additions and 57 deletions

View File

@ -1,5 +1,6 @@
from __future__ import (absolute_import, print_function, division)
import copy
import six
from .utils import Serializable, safe_subn
@ -27,27 +28,24 @@ class ODict(Serializable):
def __iter__(self):
return self.lst.__iter__()
def __getitem__(self, k):
def __getitem__(self, key):
"""
Returns a list of values matching key.
"""
ret = []
k = self._kconv(k)
for i in self.lst:
if self._kconv(i[0]) == k:
ret.append(i[1])
return ret
key = self._kconv(key)
return [
v
for k, v in self.lst
if self._kconv(k) == key
]
def keys(self):
return list(set([self._kconv(i[0]) for i in self.lst]))
def _filter_lst(self, k, lst):
k = self._kconv(k)
new = []
for i in lst:
if self._kconv(i[0]) != k:
new.append(i)
return new
return list(
set(
self._kconv(k) for k, _ in self.lst
)
)
def __len__(self):
"""
@ -81,14 +79,19 @@ class ODict(Serializable):
"""
Delete all items matching k.
"""
self.lst = self._filter_lst(k, self.lst)
def __contains__(self, k):
k = self._kconv(k)
for i in self.lst:
if self._kconv(i[0]) == k:
return True
return False
self.lst = [
i
for i in self.lst
if self._kconv(i[0]) != k
]
def __contains__(self, key):
key = self._kconv(key)
return any(
self._kconv(k) == key
for k, _ in self.lst
)
def add(self, key, value, prepend=False):
if prepend:
@ -127,40 +130,24 @@ class ODict(Serializable):
def __repr__(self):
return repr(self.lst)
def in_any(self, key, value, caseless=False):
"""
Do any of the values matching key contain value?
If caseless is true, value comparison is case-insensitive.
"""
if caseless:
value = value.lower()
for i in self[key]:
if caseless:
i = i.lower()
if value in i:
return True
return False
def replace(self, pattern, repl, *args, **kwargs):
"""
Replaces a regular expression pattern with repl in both keys and
values. Encoded content will be decoded before replacement, and
re-encoded afterwards.
values.
Returns the number of replacements made.
"""
nlst, count = [], 0
for i in self.lst:
k, c = safe_subn(pattern, repl, i[0], *args, **kwargs)
new, count = [], 0
for k, v in self.lst:
k, c = safe_subn(pattern, repl, k, *args, **kwargs)
count += c
v, c = safe_subn(pattern, repl, i[1], *args, **kwargs)
v, c = safe_subn(pattern, repl, v, *args, **kwargs)
count += c
nlst.append([k, v])
self.lst = nlst
new.append([k, v])
self.lst = new
return count
# Implement the StateObject protocol from mitmproxy
# Implement Serializable
def get_state(self):
return [tuple(i) for i in self.lst]

View File

@ -27,16 +27,6 @@ class TestODict(object):
b.set_state(state)
assert b == od
def test_in_any(self):
od = odict.ODict()
od["one"] = ["atwoa", "athreea"]
assert od.in_any("one", "two")
assert od.in_any("one", "three")
assert not od.in_any("one", "four")
assert not od.in_any("nonexistent", "foo")
assert not od.in_any("one", "TWO")
assert od.in_any("one", "TWO", True)
def test_iter(self):
od = odict.ODict()
assert not [i for i in od]