diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py index fff6e116a..765c35d6c 100644 --- a/mitmproxy/stateobject.py +++ b/mitmproxy/stateobject.py @@ -1,10 +1,18 @@ from __future__ import absolute_import import six +from typing import List, Any from netlib.utils import Serializable +def _is_list(cls): + # The typing module backport is somewhat broken. + # Python 3.5 or 3.6 should fix this. + is_list_bugfix = getattr(cls, "__origin__", False) == getattr(List[Any], "__origin__", True) + return issubclass(cls, List) or is_list_bugfix + + class StateObject(Serializable): """ @@ -28,8 +36,12 @@ class StateObject(Serializable): state = {} for attr, cls in six.iteritems(self._stateobject_attributes): val = getattr(self, attr) - if hasattr(val, "get_state"): + if val is None: + state[attr] = None + elif hasattr(val, "get_state"): state[attr] = val.get_state() + elif _is_list(cls): + state[attr] = [x.get_state() for x in val] else: state[attr] = val return state @@ -49,6 +61,9 @@ class StateObject(Serializable): elif hasattr(cls, "from_state"): obj = cls.from_state(state.pop(attr)) setattr(self, attr, obj) + elif _is_list(cls): + cls = cls.__parameters__[0] + setattr(self, attr, [cls.from_state(x) for x in state.pop(attr)]) else: # primitive types such as int, str, ... setattr(self, attr, cls(state.pop(attr))) if state: diff --git a/setup.py b/setup.py index f20073292..a0777d028 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,7 @@ setup( "requests>=2.9.1, <2.10", "six>=1.10, <1.11", "tornado>=4.3, <4.4", + "typing==3.5.1.0", "urwid>=1.3.1, <1.4", "watchdog>=0.8.3, <0.9", ], diff --git a/test/mitmproxy/test_stateobject.py b/test/mitmproxy/test_stateobject.py new file mode 100644 index 000000000..b9ffe7ae6 --- /dev/null +++ b/test/mitmproxy/test_stateobject.py @@ -0,0 +1,63 @@ +from typing import List + +from mitmproxy.stateobject import StateObject + + +class Child(StateObject): + def __init__(self, x): + self.x = x + + _stateobject_attributes = dict( + x=int + ) + + @classmethod + def from_state(cls, state): + obj = cls(None) + obj.set_state(state) + return obj + + +class Container(StateObject): + def __init__(self): + self.child = None + self.children = None + + _stateobject_attributes = dict( + child=Child, + children=List[Child], + ) + + @classmethod + def from_state(cls, state): + obj = cls() + obj.set_state(state) + return obj + + +def test_simple(): + a = Child(42) + b = a.copy() + assert b.get_state() == {"x": 42} + a.set_state({"x": 44}) + assert a.x == 44 + assert b.x == 42 + + +def test_container(): + a = Container() + a.child = Child(42) + b = a.copy() + assert a.child.x == b.child.x + b.child.x = 44 + assert a.child.x != b.child.x + + +def test_container_list(): + a = Container() + a.children = [Child(42), Child(44)] + assert a.get_state() == { + "child": None, + "children": [{"x": 42}, {"x": 44}] + } + assert len(a.copy().children) == 2