move StateObject back into libmproxy

This commit is contained in:
Maximilian Hils 2014-01-31 01:06:35 +01:00
parent 5fce7be592
commit 6ce1470631
4 changed files with 112 additions and 9 deletions

View File

@ -26,7 +26,7 @@ class ProtocolHandler(object):
This method gets called should there be an uncaught exception during the connection. This method gets called should there be an uncaught exception during the connection.
This might happen outside of handle_messages, e.g. if the initial SSL handshake fails in transparent mode. This might happen outside of handle_messages, e.g. if the initial SSL handshake fails in transparent mode.
""" """
raise NotImplementedError raise error
from . import http, tcp from . import http, tcp

View File

@ -1,10 +1,10 @@
import Cookie, urllib, urlparse, time, copy import Cookie, urllib, urlparse, time, copy
from email.utils import parsedate_tz, formatdate, mktime_tz from email.utils import parsedate_tz, formatdate, mktime_tz
import netlib.utils import netlib.utils
from netlib import http, tcp, http_status, stateobject, odict from netlib import http, tcp, http_status, odict
from netlib.odict import ODict, ODictCaseless from netlib.odict import ODict, ODictCaseless
from . import ProtocolHandler, ConnectionTypeChange, KILL from . import ProtocolHandler, ConnectionTypeChange, KILL
from .. import encoding, utils, version, filt, controller from .. import encoding, utils, version, filt, controller, stateobject
from ..proxy import ProxyError, ClientConnection, ServerConnection from ..proxy import ProxyError, ClientConnection, ServerConnection
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"

View File

@ -1,7 +1,7 @@
import os, socket, time, threading import os, socket, time, threading
from OpenSSL import SSL from OpenSSL import SSL
from netlib import tcp, http, certutils, http_auth, stateobject from netlib import tcp, http, certutils, http_auth
import utils, version, platform, controller import utils, version, platform, controller, stateobject
TRANSPARENT_SSL_PORTS = [443, 8443] TRANSPARENT_SSL_PORTS = [443, 8443]
@ -36,7 +36,7 @@ class ProxyConfig:
class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
def __init__(self, client_connection, address, server): def __init__(self, client_connection, address, server):
if client_connection: # Eventually, this object is restored from state if client_connection: # Eventually, this object is restored from state. We don't have a connection then.
tcp.BaseHandler.__init__(self, client_connection, address, server) tcp.BaseHandler.__init__(self, client_connection, address, server)
else: else:
self.address = None self.address = None
@ -49,11 +49,22 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
_stateobject_attributes = dict( _stateobject_attributes = dict(
timestamp_start=float, timestamp_start=float,
timestamp_end=float, timestamp_end=float,
timestamp_ssl_setup=float, timestamp_ssl_setup=float
address=tcp.Address,
clientcert=certutils.SSLCert
) )
def _get_state(self):
d = super(ClientConnection, self)._get_state()
d.update(
address={"address": self.address(), "use_ipv6": self.address.use_ipv6},
clientcert=self.cert.to_pem() if self.clientcert else None
)
return d
def _load_state(self, state):
super(ClientConnection, self)._load_state(state)
self.address = tcp.Address(**state["address"]) if state["address"] else None
self.clientcert = certutils.SSLCert.from_pem(state["clientcert"]) if state["clientcert"] else None
@classmethod @classmethod
def _from_state(cls, state): def _from_state(cls, state):
f = cls(None, None, None) f = cls(None, None, None)
@ -90,6 +101,23 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
cert=certutils.SSLCert cert=certutils.SSLCert
) )
def _get_state(self):
d = super(ServerConnection, self)._get_state()
d.update(
address={"address": self.address(), "use_ipv6": self.address.use_ipv6},
source_address= {"address": self.source_address(),
"use_ipv6": self.source_address.use_ipv6} if self.source_address else None,
cert=self.cert.to_pem() if self.cert else None
)
return d
def _load_state(self, state):
super(ServerConnection, self)._load_state(state)
self.address = tcp.Address(**state["address"]) if state["address"] else None
self.source_address = tcp.Address(**state["source_address"]) if state["source_address"] else None
self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None
@classmethod @classmethod
def _from_state(cls, state): def _from_state(cls, state):
f = cls(None) f = cls(None)

75
libmproxy/stateobject.py Normal file
View File

@ -0,0 +1,75 @@
class StateObject:
def _get_state(self):
raise NotImplementedError
def _load_state(self, state):
raise NotImplementedError
@classmethod
def _from_state(cls, state):
raise NotImplementedError
def __eq__(self, other):
try:
return self._get_state() == other._get_state()
except AttributeError: # we may compare with something that's not a StateObject
return False
class SimpleStateObject(StateObject):
"""
A StateObject with opionated conventions that tries to keep everything DRY.
Simply put, you agree on a list of attributes and their type.
Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves.
SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods.
Overriding _get_state or _load_state to add custom adjustments is always possible.
"""
_stateobject_attributes = None # none by default to raise an exception if definition was forgotten
"""
An attribute-name -> class-or-type dict containing all attributes that should be serialized
If the attribute is a class, this class must be a subclass of StateObject.
"""
def _get_state(self):
return {attr: self._get_state_attr(attr, cls)
for attr, cls in self._stateobject_attributes.iteritems()}
def _get_state_attr(self, attr, cls):
"""
helper for _get_state.
returns the value of the given attribute
"""
val = getattr(self, attr)
if hasattr(val, "_get_state"):
return val._get_state()
else:
return val
def _load_state(self, state):
for attr, cls in self._stateobject_attributes.iteritems():
self._load_state_attr(attr, cls, state)
def _load_state_attr(self, attr, cls, state):
"""
helper for _load_state.
loads the given attribute from the state.
"""
if state[attr] is None:
setattr(self, attr, None)
return
curr = getattr(self, attr)
if hasattr(curr, "_load_state"):
curr._load_state(state[attr])
elif hasattr(cls, "_from_state"):
setattr(self, attr, cls._from_state(state[attr]))
else:
setattr(self, attr, cls(state[attr]))
@classmethod
def _from_state(cls, state):
f = cls() # the default implementation assumes an empty constructor. Override accordingly.
f._load_state(state)
return f