diff --git a/mitmproxy/models/http.py b/mitmproxy/models/http.py index 40460182e..4bba35f1b 100644 --- a/mitmproxy/models/http.py +++ b/mitmproxy/models/http.py @@ -26,14 +26,6 @@ class MessageMixin(object): return self.content return encoding.decode(ce, self.content) - def copy(self): - c = copy.copy(self) - if hasattr(self, "data"): # FIXME remove condition - c.data = copy.copy(self.data) - - c.headers = self.headers.copy() - return c - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both the headers diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 9b8fdae4a..bcb828da6 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -5,7 +5,7 @@ Unicode Handling See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ """ from __future__ import absolute_import, print_function, division -import copy + try: from collections.abc import MutableMapping except ImportError: # pragma: no cover @@ -190,9 +190,6 @@ class Headers(MutableMapping, Serializable): [name, value] for value in values ) - def copy(self): - return Headers(copy.copy(self.fields)) - def get_state(self): return tuple(tuple(field) for field in self.fields) diff --git a/netlib/http/message.py b/netlib/http/message.py index b6d846baa..b265ac4ff 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -43,9 +43,6 @@ class MessageData(utils.Serializable): class Message(utils.Serializable): - def __init__(self, data): - self.data = data - def __eq__(self, other): if isinstance(other, Message): return self.data == other.data @@ -62,6 +59,7 @@ class Message(utils.Serializable): @classmethod def from_state(cls, state): + state["headers"] = Headers.from_state(state["headers"]) return cls(**state) @property diff --git a/netlib/http/request.py b/netlib/http/request.py index d35c1874e..5bd2547ed 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -42,8 +42,7 @@ class Request(Message): An HTTP request. """ def __init__(self, *args, **kwargs): - data = RequestData(*args, **kwargs) - super(Request, self).__init__(data) + self.data = RequestData(*args, **kwargs) def __repr__(self): if self.host and self.port: diff --git a/netlib/http/response.py b/netlib/http/response.py index da2c81421..8af3c0412 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -29,8 +29,7 @@ class Response(Message): An HTTP response. """ def __init__(self, *args, **kwargs): - data = ResponseData(*args, **kwargs) - super(Response, self).__init__(data) + self.data = ResponseData(*args, **kwargs) def __repr__(self): if self.content: diff --git a/netlib/utils.py b/netlib/utils.py index f7bb5c4bd..09be29d92 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -41,6 +41,9 @@ class Serializable(object): """ raise NotImplementedError() + def copy(self): + return self.from_state(self.get_state()) + def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): diff --git a/test/netlib/test_utils.py b/test/netlib/test_utils.py index b096e5bca..fcb63eb2f 100644 --- a/test/netlib/test_utils.py +++ b/test/netlib/test_utils.py @@ -139,3 +139,30 @@ def test_parse_content_type(): v = p("text/html; charset=UTF-8") assert v == ('text', 'html', {'charset': 'UTF-8'}) + + +class SerializableDummy(utils.Serializable): + def __init__(self, i): + self.i = i + + def get_state(self): + return self.i + + def set_state(self, i): + self.i = i + + def from_state(self, state): + return type(self)(state) + + +class TestSerializable: + + def test_copy(self): + a = SerializableDummy(42) + assert a.i == 42 + b = a.copy() + assert b.i == 42 + + a.set_state(1) + assert a.i == 1 + assert b.i == 42