From fe0ed63c4a3486402f65638b476149ebba752055 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:16:58 +0100 Subject: [PATCH 1/3] add Serializable ABC --- netlib/certutils.py | 26 +++++++++++++++++++------- netlib/http/headers.py | 7 +++---- netlib/http/message.py | 33 ++++++++++++++++++++++++++++++--- netlib/http/request.py | 5 ++--- netlib/http/response.py | 5 ++--- netlib/odict.py | 10 ++++++---- netlib/tcp.py | 17 ++++++++++++++++- netlib/utils.py | 26 +++++++++++++++++++++++++- test/http/test_headers.py | 2 +- test/http/test_request.py | 2 +- test/http/test_response.py | 2 +- test/test_odict.py | 2 +- 12 files changed, 107 insertions(+), 30 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index a0111381c..ecdc06241 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -13,6 +13,8 @@ from pyasn1.error import PyAsn1Error import OpenSSL # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 +from netlib.utils import Serializable + DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 # Generated with "openssl dhparam". It's too slow to generate this on startup. DEFAULT_DHPARAM = b""" @@ -361,7 +363,7 @@ class _GeneralNames(univ.SequenceOf): constraint.ValueSizeConstraint(1, 1024) -class SSLCert(object): +class SSLCert(Serializable): def __init__(self, cert): """ @@ -375,15 +377,25 @@ class SSLCert(object): def __ne__(self, other): return not self.__eq__(other) - @classmethod - def from_pem(klass, txt): - x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) - return klass(x509) + def get_state(self): + return self.to_pem() + + def set_state(self, state): + self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state) @classmethod - def from_der(klass, der): + def from_state(cls, state): + cls.from_pem(state) + + @classmethod + def from_pem(cls, txt): + x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt) + return cls(x509) + + @classmethod + def from_der(cls, der): pem = ssl.DER_cert_to_PEM_cert(der) - return klass.from_pem(pem) + return cls.from_pem(pem) def to_pem(self): return OpenSSL.crypto.dump_certificate( diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 6eb9db92d..784047964 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -14,7 +14,7 @@ except ImportError: # pragma: nocover import six -from netlib.utils import always_byte_args, always_bytes +from netlib.utils import always_byte_args, always_bytes, Serializable if six.PY2: # pragma: nocover _native = lambda x: x @@ -27,7 +27,7 @@ else: _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping): +class Headers(MutableMapping, Serializable): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -193,11 +193,10 @@ class Headers(MutableMapping): def copy(self): return Headers(copy.copy(self.fields)) - # Implement the StateObject protocol from mitmproxy def get_state(self): return tuple(tuple(field) for field in self.fields) - def load_state(self, state): + def set_state(self, state): self.fields = [list(field) for field in state] @classmethod diff --git a/netlib/http/message.py b/netlib/http/message.py index 28f55fa25..3d65f93e5 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,9 +4,10 @@ import warnings import six +from netlib.utils import Serializable +from .headers import Headers from .. import encoding, utils - CONTENT_MISSING = 0 if six.PY2: # pragma: nocover @@ -18,7 +19,7 @@ else: _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") -class MessageData(object): +class MessageData(Serializable): def __eq__(self, other): if isinstance(other, MessageData): return self.__dict__ == other.__dict__ @@ -27,8 +28,24 @@ class MessageData(object): def __ne__(self, other): return not self.__eq__(other) + def set_state(self, state): + for k, v in state.items(): + if k == "headers": + v = Headers.from_state(v) + setattr(self, k, v) -class Message(object): + def get_state(self): + state = vars(self).copy() + state["headers"] = state["headers"].get_state() + return state + + @classmethod + def from_state(cls, state): + state["headers"] = Headers.from_state(state["headers"]) + return cls(**state) + + +class Message(Serializable): def __init__(self, data): self.data = data @@ -40,6 +57,16 @@ class Message(object): def __ne__(self, other): return not self.__eq__(other) + def get_state(self): + return self.data.get_state() + + def set_state(self, state): + self.data.set_state(state) + + @classmethod + def from_state(cls, state): + return cls(**state) + @property def headers(self): """ diff --git a/netlib/http/request.py b/netlib/http/request.py index 6dabb1896..0e0f88cec 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -16,9 +16,8 @@ from .message import Message, _native, _always_bytes, MessageData class RequestData(MessageData): def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, timestamp_start=None, timestamp_end=None): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) + if not isinstance(headers, Headers): + headers = Headers(headers) self.first_line_format = first_line_format self.method = method diff --git a/netlib/http/response.py b/netlib/http/response.py index 66e5ded60..8f4d62158 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -12,9 +12,8 @@ from ..odict import ODict class ResponseData(MessageData): def __init__(self, http_version, status_code, reason=None, headers=None, content=None, timestamp_start=None, timestamp_end=None): - if not headers: - headers = Headers() - assert isinstance(headers, Headers) + if not isinstance(headers, Headers): + headers = Headers(headers) self.http_version = http_version self.status_code = status_code diff --git a/netlib/odict.py b/netlib/odict.py index 90317e5e8..1e6e381af 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -3,6 +3,8 @@ import re import copy import six +from .utils import Serializable + def safe_subn(pattern, repl, target, *args, **kwargs): """ @@ -13,7 +15,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs): return re.subn(str(pattern), str(repl), target, *args, **kwargs) -class ODict(object): +class ODict(Serializable): """ A dictionary-like object for managing ordered (key, value) data. Think @@ -172,12 +174,12 @@ class ODict(object): def get_state(self): return [tuple(i) for i in self.lst] - def load_state(self, state): + def set_state(self, state): self.lst = [list(i) for i in state] @classmethod - def from_state(klass, state): - return klass([list(i) for i in state]) + def from_state(cls, state): + return cls([list(i) for i in state]) class ODictCaseless(ODict): diff --git a/netlib/tcp.py b/netlib/tcp.py index 85b4b0e2e..2e91a70cd 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,6 +16,7 @@ import six import OpenSSL from OpenSSL import SSL +from netlib.utils import Serializable from . import certutils, version_check # This is a rather hackish way to make sure that @@ -298,7 +299,7 @@ class Reader(_FileLike): raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -class Address(object): +class Address(Serializable): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and @@ -309,6 +310,20 @@ class Address(object): self.address = tuple(address) self.use_ipv6 = use_ipv6 + def get_state(self): + return { + "address": self.address, + "use_ipv6": self.use_ipv6 + } + + def set_state(self, state): + self.address = state["address"] + self.use_ipv6 = state["use_ipv6"] + + @classmethod + def from_state(cls, state): + return Address(**state) + @classmethod def wrap(cls, t): if isinstance(t, cls): diff --git a/netlib/utils.py b/netlib/utils.py index 1c1b617ac..a0c2035cf 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,14 +1,38 @@ from __future__ import absolute_import, print_function, division import os.path import re -import string import codecs import unicodedata +from abc import ABCMeta, abstractmethod + import six from six.moves import urllib import hyperframe + +@six.add_metaclass(ABCMeta) +class Serializable(object): + """ + ABC for Python's pickle protocol __getstate__ and __setstate__. + """ + + @classmethod + @abstractmethod + def from_state(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + @abstractmethod + def get_state(self): + raise NotImplementedError() + + @abstractmethod + def set_state(self, state): + raise NotImplementedError() + + def always_bytes(unicode_or_bytes, *encode_args): if isinstance(unicode_or_bytes, six.text_type): return unicode_or_bytes.encode(*encode_args) diff --git a/test/http/test_headers.py b/test/http/test_headers.py index 8bddc0b2f..d50fee3e4 100644 --- a/test/http/test_headers.py +++ b/test/http/test_headers.py @@ -148,5 +148,5 @@ class TestHeaders(object): headers2 = Headers() assert headers != headers2 - headers2.load_state(headers.get_state()) + headers2.set_state(headers.get_state()) assert headers == headers2 diff --git a/test/http/test_request.py b/test/http/test_request.py index 8cf69ffe2..1deee3879 100644 --- a/test/http/test_request.py +++ b/test/http/test_request.py @@ -12,7 +12,7 @@ from .test_message import _test_decoded_attr, _test_passthrough_attr class TestRequestData(object): def test_init(self): - with raises(AssertionError): + with raises(ValueError): treq(headers="foobar") assert isinstance(treq(headers=None).headers, Headers) diff --git a/test/http/test_response.py b/test/http/test_response.py index a1f4abd72..c7d90b160 100644 --- a/test/http/test_response.py +++ b/test/http/test_response.py @@ -8,7 +8,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr class TestResponseData(object): def test_init(self): - with raises(AssertionError): + with raises(ValueError): tresp(headers="foobar") assert isinstance(tresp(headers=None).headers, Headers) diff --git a/test/test_odict.py b/test/test_odict.py index 881970263..f0985ef64 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -24,7 +24,7 @@ class TestODict(object): nd = odict.ODict.from_state(state) assert nd == od b = odict.ODict() - b.load_state(state) + b.set_state(state) assert b == od def test_in_any(self): From 173ff0b235cdb45a8923f313807d9804830c2a2b Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:28:49 +0100 Subject: [PATCH 2/3] fix py3 compat --- netlib/certutils.py | 3 ++- netlib/http/message.py | 5 ++--- netlib/tcp.py | 5 ++--- test/http/test_request.py | 2 +- test/http/test_response.py | 4 +++- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/netlib/certutils.py b/netlib/certutils.py index ecdc06241..616a778e4 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -12,8 +12,9 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL +from .utils import Serializable + # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 -from netlib.utils import Serializable DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 # Generated with "openssl dhparam". It's too slow to generate this on startup. diff --git a/netlib/http/message.py b/netlib/http/message.py index 3d65f93e5..e3d8ce375 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,7 +4,6 @@ import warnings import six -from netlib.utils import Serializable from .headers import Headers from .. import encoding, utils @@ -19,7 +18,7 @@ else: _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") -class MessageData(Serializable): +class MessageData(utils.Serializable): def __eq__(self, other): if isinstance(other, MessageData): return self.__dict__ == other.__dict__ @@ -45,7 +44,7 @@ class MessageData(Serializable): return cls(**state) -class Message(Serializable): +class Message(utils.Serializable): def __init__(self, data): self.data = data diff --git a/netlib/tcp.py b/netlib/tcp.py index 2e91a70cd..c8548aea2 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,8 +16,7 @@ import six import OpenSSL from OpenSSL import SSL -from netlib.utils import Serializable -from . import certutils, version_check +from . import certutils, version_check, utils # This is a rather hackish way to make sure that # the latest version of pyOpenSSL is actually installed. @@ -299,7 +298,7 @@ class Reader(_FileLike): raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -class Address(Serializable): +class Address(utils.Serializable): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and diff --git a/test/http/test_request.py b/test/http/test_request.py index 1deee3879..900b2cd1c 100644 --- a/test/http/test_request.py +++ b/test/http/test_request.py @@ -12,7 +12,7 @@ from .test_message import _test_decoded_attr, _test_passthrough_attr class TestRequestData(object): def test_init(self): - with raises(ValueError): + with raises(ValueError if six.PY2 else TypeError): treq(headers="foobar") assert isinstance(treq(headers=None).headers, Headers) diff --git a/test/http/test_response.py b/test/http/test_response.py index c7d90b160..145880000 100644 --- a/test/http/test_response.py +++ b/test/http/test_response.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, print_function, division +import six + from netlib.http import Headers from netlib.odict import ODict, ODictCaseless from netlib.tutils import raises, tresp @@ -8,7 +10,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr class TestResponseData(object): def test_init(self): - with raises(ValueError): + with raises(ValueError if six.PY2 else TypeError): tresp(headers="foobar") assert isinstance(tresp(headers=None).headers, Headers) From 655b521749efd5a600d342a1d95b67d32da280a8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:33:10 +0100 Subject: [PATCH 3/3] fix docstrings --- netlib/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/netlib/utils.py b/netlib/utils.py index a0c2035cf..d2fc7195b 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -14,22 +14,29 @@ import hyperframe @six.add_metaclass(ABCMeta) class Serializable(object): """ - ABC for Python's pickle protocol __getstate__ and __setstate__. + Abstract Base Class that defines an API to save an object's state and restore it later on. """ @classmethod @abstractmethod def from_state(cls, state): - obj = cls.__new__(cls) - obj.__setstate__(state) - return obj + """ + Create a new object from the given state. + """ + raise NotImplementedError() @abstractmethod def get_state(self): + """ + Retrieve object state. + """ raise NotImplementedError() @abstractmethod def set_state(self, state): + """ + Set object state to the given state. + """ raise NotImplementedError()