add Serializable ABC

This commit is contained in:
Maximilian Hils 2016-02-08 04:16:58 +01:00
parent 4873547de3
commit fe0ed63c4a
12 changed files with 107 additions and 30 deletions

View File

@ -13,6 +13,8 @@ from pyasn1.error import PyAsn1Error
import OpenSSL
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
from netlib.utils import Serializable
DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3
# Generated with "openssl dhparam". It's too slow to generate this on startup.
DEFAULT_DHPARAM = b"""
@ -361,7 +363,7 @@ class _GeneralNames(univ.SequenceOf):
constraint.ValueSizeConstraint(1, 1024)
class SSLCert(object):
class SSLCert(Serializable):
def __init__(self, cert):
"""
@ -375,15 +377,25 @@ class SSLCert(object):
def __ne__(self, other):
return not self.__eq__(other)
@classmethod
def from_pem(klass, txt):
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
return klass(x509)
def get_state(self):
return self.to_pem()
def set_state(self, state):
self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state)
@classmethod
def from_der(klass, der):
def from_state(cls, state):
cls.from_pem(state)
@classmethod
def from_pem(cls, txt):
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
return cls(x509)
@classmethod
def from_der(cls, der):
pem = ssl.DER_cert_to_PEM_cert(der)
return klass.from_pem(pem)
return cls.from_pem(pem)
def to_pem(self):
return OpenSSL.crypto.dump_certificate(

View File

@ -14,7 +14,7 @@ except ImportError: # pragma: nocover
import six
from netlib.utils import always_byte_args, always_bytes
from netlib.utils import always_byte_args, always_bytes, Serializable
if six.PY2: # pragma: nocover
_native = lambda x: x
@ -27,7 +27,7 @@ else:
_always_byte_args = always_byte_args("utf-8", "surrogateescape")
class Headers(MutableMapping):
class Headers(MutableMapping, Serializable):
"""
Header class which allows both convenient access to individual headers as well as
direct access to the underlying raw data. Provides a full dictionary interface.
@ -193,11 +193,10 @@ class Headers(MutableMapping):
def copy(self):
return Headers(copy.copy(self.fields))
# Implement the StateObject protocol from mitmproxy
def get_state(self):
return tuple(tuple(field) for field in self.fields)
def load_state(self, state):
def set_state(self, state):
self.fields = [list(field) for field in state]
@classmethod

View File

@ -4,9 +4,10 @@ import warnings
import six
from netlib.utils import Serializable
from .headers import Headers
from .. import encoding, utils
CONTENT_MISSING = 0
if six.PY2: # pragma: nocover
@ -18,7 +19,7 @@ else:
_always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape")
class MessageData(object):
class MessageData(Serializable):
def __eq__(self, other):
if isinstance(other, MessageData):
return self.__dict__ == other.__dict__
@ -27,8 +28,24 @@ class MessageData(object):
def __ne__(self, other):
return not self.__eq__(other)
def set_state(self, state):
for k, v in state.items():
if k == "headers":
v = Headers.from_state(v)
setattr(self, k, v)
class Message(object):
def get_state(self):
state = vars(self).copy()
state["headers"] = state["headers"].get_state()
return state
@classmethod
def from_state(cls, state):
state["headers"] = Headers.from_state(state["headers"])
return cls(**state)
class Message(Serializable):
def __init__(self, data):
self.data = data
@ -40,6 +57,16 @@ class Message(object):
def __ne__(self, other):
return not self.__eq__(other)
def get_state(self):
return self.data.get_state()
def set_state(self, state):
self.data.set_state(state)
@classmethod
def from_state(cls, state):
return cls(**state)
@property
def headers(self):
"""

View File

@ -16,9 +16,8 @@ from .message import Message, _native, _always_bytes, MessageData
class RequestData(MessageData):
def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None,
timestamp_start=None, timestamp_end=None):
if not headers:
headers = Headers()
assert isinstance(headers, Headers)
if not isinstance(headers, Headers):
headers = Headers(headers)
self.first_line_format = first_line_format
self.method = method

View File

@ -12,9 +12,8 @@ from ..odict import ODict
class ResponseData(MessageData):
def __init__(self, http_version, status_code, reason=None, headers=None, content=None,
timestamp_start=None, timestamp_end=None):
if not headers:
headers = Headers()
assert isinstance(headers, Headers)
if not isinstance(headers, Headers):
headers = Headers(headers)
self.http_version = http_version
self.status_code = status_code

View File

@ -3,6 +3,8 @@ import re
import copy
import six
from .utils import Serializable
def safe_subn(pattern, repl, target, *args, **kwargs):
"""
@ -13,7 +15,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
class ODict(object):
class ODict(Serializable):
"""
A dictionary-like object for managing ordered (key, value) data. Think
@ -172,12 +174,12 @@ class ODict(object):
def get_state(self):
return [tuple(i) for i in self.lst]
def load_state(self, state):
def set_state(self, state):
self.lst = [list(i) for i in state]
@classmethod
def from_state(klass, state):
return klass([list(i) for i in state])
def from_state(cls, state):
return cls([list(i) for i in state])
class ODictCaseless(ODict):

View File

@ -16,6 +16,7 @@ import six
import OpenSSL
from OpenSSL import SSL
from netlib.utils import Serializable
from . import certutils, version_check
# This is a rather hackish way to make sure that
@ -298,7 +299,7 @@ class Reader(_FileLike):
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
class Address(object):
class Address(Serializable):
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and
@ -309,6 +310,20 @@ class Address(object):
self.address = tuple(address)
self.use_ipv6 = use_ipv6
def get_state(self):
return {
"address": self.address,
"use_ipv6": self.use_ipv6
}
def set_state(self, state):
self.address = state["address"]
self.use_ipv6 = state["use_ipv6"]
@classmethod
def from_state(cls, state):
return Address(**state)
@classmethod
def wrap(cls, t):
if isinstance(t, cls):

View File

@ -1,14 +1,38 @@
from __future__ import absolute_import, print_function, division
import os.path
import re
import string
import codecs
import unicodedata
from abc import ABCMeta, abstractmethod
import six
from six.moves import urllib
import hyperframe
@six.add_metaclass(ABCMeta)
class Serializable(object):
"""
ABC for Python's pickle protocol __getstate__ and __setstate__.
"""
@classmethod
@abstractmethod
def from_state(cls, state):
obj = cls.__new__(cls)
obj.__setstate__(state)
return obj
@abstractmethod
def get_state(self):
raise NotImplementedError()
@abstractmethod
def set_state(self, state):
raise NotImplementedError()
def always_bytes(unicode_or_bytes, *encode_args):
if isinstance(unicode_or_bytes, six.text_type):
return unicode_or_bytes.encode(*encode_args)

View File

@ -148,5 +148,5 @@ class TestHeaders(object):
headers2 = Headers()
assert headers != headers2
headers2.load_state(headers.get_state())
headers2.set_state(headers.get_state())
assert headers == headers2

View File

@ -12,7 +12,7 @@ from .test_message import _test_decoded_attr, _test_passthrough_attr
class TestRequestData(object):
def test_init(self):
with raises(AssertionError):
with raises(ValueError):
treq(headers="foobar")
assert isinstance(treq(headers=None).headers, Headers)

View File

@ -8,7 +8,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr
class TestResponseData(object):
def test_init(self):
with raises(AssertionError):
with raises(ValueError):
tresp(headers="foobar")
assert isinstance(tresp(headers=None).headers, Headers)

View File

@ -24,7 +24,7 @@ class TestODict(object):
nd = odict.ODict.from_state(state)
assert nd == od
b = odict.ODict()
b.load_state(state)
b.set_state(state)
assert b == od
def test_in_any(self):