Merge pull request #120 from mitmproxy/model-cleanup

Model Cleanup
This commit is contained in:
Thomas Kriechbaumer 2016-02-08 09:52:29 +01:00
commit 4ee1ad88fc
12 changed files with 116 additions and 31 deletions

View File

@ -12,7 +12,10 @@ from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error from pyasn1.error import PyAsn1Error
import OpenSSL import OpenSSL
from .utils import Serializable
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3 DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3
# Generated with "openssl dhparam". It's too slow to generate this on startup. # Generated with "openssl dhparam". It's too slow to generate this on startup.
DEFAULT_DHPARAM = b""" DEFAULT_DHPARAM = b"""
@ -361,7 +364,7 @@ class _GeneralNames(univ.SequenceOf):
constraint.ValueSizeConstraint(1, 1024) constraint.ValueSizeConstraint(1, 1024)
class SSLCert(object): class SSLCert(Serializable):
def __init__(self, cert): def __init__(self, cert):
""" """
@ -375,15 +378,25 @@ class SSLCert(object):
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
@classmethod def get_state(self):
def from_pem(klass, txt): return self.to_pem()
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
return klass(x509) def set_state(self, state):
self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state)
@classmethod @classmethod
def from_der(klass, der): def from_state(cls, state):
cls.from_pem(state)
@classmethod
def from_pem(cls, txt):
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
return cls(x509)
@classmethod
def from_der(cls, der):
pem = ssl.DER_cert_to_PEM_cert(der) pem = ssl.DER_cert_to_PEM_cert(der)
return klass.from_pem(pem) return cls.from_pem(pem)
def to_pem(self): def to_pem(self):
return OpenSSL.crypto.dump_certificate( return OpenSSL.crypto.dump_certificate(

View File

@ -14,7 +14,7 @@ except ImportError: # pragma: nocover
import six import six
from netlib.utils import always_byte_args, always_bytes from netlib.utils import always_byte_args, always_bytes, Serializable
if six.PY2: # pragma: nocover if six.PY2: # pragma: nocover
_native = lambda x: x _native = lambda x: x
@ -27,7 +27,7 @@ else:
_always_byte_args = always_byte_args("utf-8", "surrogateescape") _always_byte_args = always_byte_args("utf-8", "surrogateescape")
class Headers(MutableMapping): class Headers(MutableMapping, Serializable):
""" """
Header class which allows both convenient access to individual headers as well as Header class which allows both convenient access to individual headers as well as
direct access to the underlying raw data. Provides a full dictionary interface. direct access to the underlying raw data. Provides a full dictionary interface.
@ -193,11 +193,10 @@ class Headers(MutableMapping):
def copy(self): def copy(self):
return Headers(copy.copy(self.fields)) return Headers(copy.copy(self.fields))
# Implement the StateObject protocol from mitmproxy
def get_state(self): def get_state(self):
return tuple(tuple(field) for field in self.fields) return tuple(tuple(field) for field in self.fields)
def load_state(self, state): def set_state(self, state):
self.fields = [list(field) for field in state] self.fields = [list(field) for field in state]
@classmethod @classmethod

View File

@ -4,9 +4,9 @@ import warnings
import six import six
from .headers import Headers
from .. import encoding, utils from .. import encoding, utils
CONTENT_MISSING = 0 CONTENT_MISSING = 0
if six.PY2: # pragma: nocover if six.PY2: # pragma: nocover
@ -18,7 +18,7 @@ else:
_always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape")
class MessageData(object): class MessageData(utils.Serializable):
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, MessageData): if isinstance(other, MessageData):
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
@ -27,8 +27,24 @@ class MessageData(object):
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def set_state(self, state):
for k, v in state.items():
if k == "headers":
v = Headers.from_state(v)
setattr(self, k, v)
class Message(object): def get_state(self):
state = vars(self).copy()
state["headers"] = state["headers"].get_state()
return state
@classmethod
def from_state(cls, state):
state["headers"] = Headers.from_state(state["headers"])
return cls(**state)
class Message(utils.Serializable):
def __init__(self, data): def __init__(self, data):
self.data = data self.data = data
@ -40,6 +56,16 @@ class Message(object):
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def get_state(self):
return self.data.get_state()
def set_state(self, state):
self.data.set_state(state)
@classmethod
def from_state(cls, state):
return cls(**state)
@property @property
def headers(self): def headers(self):
""" """

View File

@ -16,9 +16,8 @@ from .message import Message, _native, _always_bytes, MessageData
class RequestData(MessageData): class RequestData(MessageData):
def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None,
timestamp_start=None, timestamp_end=None): timestamp_start=None, timestamp_end=None):
if not headers: if not isinstance(headers, Headers):
headers = Headers() headers = Headers(headers)
assert isinstance(headers, Headers)
self.first_line_format = first_line_format self.first_line_format = first_line_format
self.method = method self.method = method

View File

@ -12,9 +12,8 @@ from ..odict import ODict
class ResponseData(MessageData): class ResponseData(MessageData):
def __init__(self, http_version, status_code, reason=None, headers=None, content=None, def __init__(self, http_version, status_code, reason=None, headers=None, content=None,
timestamp_start=None, timestamp_end=None): timestamp_start=None, timestamp_end=None):
if not headers: if not isinstance(headers, Headers):
headers = Headers() headers = Headers(headers)
assert isinstance(headers, Headers)
self.http_version = http_version self.http_version = http_version
self.status_code = status_code self.status_code = status_code

View File

@ -3,6 +3,8 @@ import re
import copy import copy
import six import six
from .utils import Serializable
def safe_subn(pattern, repl, target, *args, **kwargs): def safe_subn(pattern, repl, target, *args, **kwargs):
""" """
@ -13,7 +15,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
return re.subn(str(pattern), str(repl), target, *args, **kwargs) return re.subn(str(pattern), str(repl), target, *args, **kwargs)
class ODict(object): class ODict(Serializable):
""" """
A dictionary-like object for managing ordered (key, value) data. Think A dictionary-like object for managing ordered (key, value) data. Think
@ -172,12 +174,12 @@ class ODict(object):
def get_state(self): def get_state(self):
return [tuple(i) for i in self.lst] return [tuple(i) for i in self.lst]
def load_state(self, state): def set_state(self, state):
self.lst = [list(i) for i in state] self.lst = [list(i) for i in state]
@classmethod @classmethod
def from_state(klass, state): def from_state(cls, state):
return klass([list(i) for i in state]) return cls([list(i) for i in state])
class ODictCaseless(ODict): class ODictCaseless(ODict):

View File

@ -16,7 +16,7 @@ import six
import OpenSSL import OpenSSL
from OpenSSL import SSL from OpenSSL import SSL
from . import certutils, version_check from . import certutils, version_check, utils
# This is a rather hackish way to make sure that # This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed. # the latest version of pyOpenSSL is actually installed.
@ -298,7 +298,7 @@ class Reader(_FileLike):
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
class Address(object): class Address(utils.Serializable):
""" """
This class wraps an IPv4/IPv6 tuple to provide named attributes and This class wraps an IPv4/IPv6 tuple to provide named attributes and
@ -309,6 +309,20 @@ class Address(object):
self.address = tuple(address) self.address = tuple(address)
self.use_ipv6 = use_ipv6 self.use_ipv6 = use_ipv6
def get_state(self):
return {
"address": self.address,
"use_ipv6": self.use_ipv6
}
def set_state(self, state):
self.address = state["address"]
self.use_ipv6 = state["use_ipv6"]
@classmethod
def from_state(cls, state):
return Address(**state)
@classmethod @classmethod
def wrap(cls, t): def wrap(cls, t):
if isinstance(t, cls): if isinstance(t, cls):

View File

@ -1,14 +1,45 @@
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import os.path import os.path
import re import re
import string
import codecs import codecs
import unicodedata import unicodedata
from abc import ABCMeta, abstractmethod
import six import six
from six.moves import urllib from six.moves import urllib
import hyperframe import hyperframe
@six.add_metaclass(ABCMeta)
class Serializable(object):
"""
Abstract Base Class that defines an API to save an object's state and restore it later on.
"""
@classmethod
@abstractmethod
def from_state(cls, state):
"""
Create a new object from the given state.
"""
raise NotImplementedError()
@abstractmethod
def get_state(self):
"""
Retrieve object state.
"""
raise NotImplementedError()
@abstractmethod
def set_state(self, state):
"""
Set object state to the given state.
"""
raise NotImplementedError()
def always_bytes(unicode_or_bytes, *encode_args): def always_bytes(unicode_or_bytes, *encode_args):
if isinstance(unicode_or_bytes, six.text_type): if isinstance(unicode_or_bytes, six.text_type):
return unicode_or_bytes.encode(*encode_args) return unicode_or_bytes.encode(*encode_args)

View File

@ -148,5 +148,5 @@ class TestHeaders(object):
headers2 = Headers() headers2 = Headers()
assert headers != headers2 assert headers != headers2
headers2.load_state(headers.get_state()) headers2.set_state(headers.get_state())
assert headers == headers2 assert headers == headers2

View File

@ -12,7 +12,7 @@ from .test_message import _test_decoded_attr, _test_passthrough_attr
class TestRequestData(object): class TestRequestData(object):
def test_init(self): def test_init(self):
with raises(AssertionError): with raises(ValueError if six.PY2 else TypeError):
treq(headers="foobar") treq(headers="foobar")
assert isinstance(treq(headers=None).headers, Headers) assert isinstance(treq(headers=None).headers, Headers)

View File

@ -1,5 +1,7 @@
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import six
from netlib.http import Headers from netlib.http import Headers
from netlib.odict import ODict, ODictCaseless from netlib.odict import ODict, ODictCaseless
from netlib.tutils import raises, tresp from netlib.tutils import raises, tresp
@ -8,7 +10,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr
class TestResponseData(object): class TestResponseData(object):
def test_init(self): def test_init(self):
with raises(AssertionError): with raises(ValueError if six.PY2 else TypeError):
tresp(headers="foobar") tresp(headers="foobar")
assert isinstance(tresp(headers=None).headers, Headers) assert isinstance(tresp(headers=None).headers, Headers)

View File

@ -24,7 +24,7 @@ class TestODict(object):
nd = odict.ODict.from_state(state) nd = odict.ODict.from_state(state)
assert nd == od assert nd == od
b = odict.ODict() b = odict.ODict()
b.load_state(state) b.set_state(state)
assert b == od assert b == od
def test_in_any(self): def test_in_any(self):