stateobject: support lists

This commit is contained in:
Maximilian Hils 2016-04-29 20:34:12 -07:00
parent cb1119f3ee
commit 74cfd7a4e2
3 changed files with 80 additions and 1 deletions

View File

@ -1,10 +1,18 @@
from __future__ import absolute_import
import six
from typing import List, Any
from netlib.utils import Serializable
def _is_list(cls):
# The typing module backport is somewhat broken.
# Python 3.5 or 3.6 should fix this.
is_list_bugfix = getattr(cls, "__origin__", False) == getattr(List[Any], "__origin__", True)
return issubclass(cls, List) or is_list_bugfix
class StateObject(Serializable):
"""
@ -28,8 +36,12 @@ class StateObject(Serializable):
state = {}
for attr, cls in six.iteritems(self._stateobject_attributes):
val = getattr(self, attr)
if hasattr(val, "get_state"):
if val is None:
state[attr] = None
elif hasattr(val, "get_state"):
state[attr] = val.get_state()
elif _is_list(cls):
state[attr] = [x.get_state() for x in val]
else:
state[attr] = val
return state
@ -49,6 +61,9 @@ class StateObject(Serializable):
elif hasattr(cls, "from_state"):
obj = cls.from_state(state.pop(attr))
setattr(self, attr, obj)
elif _is_list(cls):
cls = cls.__parameters__[0]
setattr(self, attr, [cls.from_state(x) for x in state.pop(attr)])
else: # primitive types such as int, str, ...
setattr(self, attr, cls(state.pop(attr)))
if state:

View File

@ -81,6 +81,7 @@ setup(
"requests>=2.9.1, <2.10",
"six>=1.10, <1.11",
"tornado>=4.3, <4.4",
"typing==3.5.1.0",
"urwid>=1.3.1, <1.4",
"watchdog>=0.8.3, <0.9",
],

View File

@ -0,0 +1,63 @@
from typing import List
from mitmproxy.stateobject import StateObject
class Child(StateObject):
def __init__(self, x):
self.x = x
_stateobject_attributes = dict(
x=int
)
@classmethod
def from_state(cls, state):
obj = cls(None)
obj.set_state(state)
return obj
class Container(StateObject):
def __init__(self):
self.child = None
self.children = None
_stateobject_attributes = dict(
child=Child,
children=List[Child],
)
@classmethod
def from_state(cls, state):
obj = cls()
obj.set_state(state)
return obj
def test_simple():
a = Child(42)
b = a.copy()
assert b.get_state() == {"x": 42}
a.set_state({"x": 44})
assert a.x == 44
assert b.x == 42
def test_container():
a = Container()
a.child = Child(42)
b = a.copy()
assert a.child.x == b.child.x
b.child.x = 44
assert a.child.x != b.child.x
def test_container_list():
a = Container()
a.children = [Child(42), Child(44)]
assert a.get_state() == {
"child": None,
"children": [{"x": 42}, {"x": 44}]
}
assert len(a.copy().children) == 2