make stateobject simpler and stricter

This commit is contained in:
Maximilian Hils 2016-02-08 04:19:25 +01:00
parent e9934cc008
commit bdb763d9cf
8 changed files with 87 additions and 111 deletions

View File

@ -29,6 +29,7 @@ def convert_015_016(data):
data[m]["http_version"] = data[m].pop("httpversion") data[m]["http_version"] = data[m].pop("httpversion")
if "msg" in data["response"]: if "msg" in data["response"]:
data["response"]["reason"] = data["response"].pop("msg") data["response"]["reason"] = data["response"].pop("msg")
data["request"].pop("form_out", None)
data["version"] = (0, 16) data["version"] = (0, 16)
return data return data

View File

@ -42,28 +42,14 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
return self.ssl_established return self.ssl_established
_stateobject_attributes = dict( _stateobject_attributes = dict(
address=tcp.Address,
clientcert=certutils.SSLCert,
ssl_established=bool, ssl_established=bool,
timestamp_start=float, timestamp_start=float,
timestamp_end=float, timestamp_end=float,
timestamp_ssl_setup=float timestamp_ssl_setup=float
) )
def get_state(self):
d = super(ClientConnection, self).get_state()
d.update(
address=({
"address": self.address(),
"use_ipv6": self.address.use_ipv6} if self.address else {}),
clientcert=self.cert.to_pem() if self.clientcert else None)
return d
def load_state(self, state):
super(ClientConnection, self).load_state(state)
self.address = tcp.Address(
**state["address"]) if state["address"] else None
self.clientcert = certutils.SSLCert.from_pem(
state["clientcert"]) if state["clientcert"] else None
def copy(self): def copy(self):
return copy.copy(self) return copy.copy(self)
@ -76,7 +62,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
f = cls(None, tuple(), None) f = cls(None, tuple(), None)
f.load_state(state) f.set_state(state)
return f return f
def convert_to_ssl(self, *args, **kwargs): def convert_to_ssl(self, *args, **kwargs):
@ -131,31 +117,10 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
sni=str sni=str
) )
def get_state(self):
d = super(ServerConnection, self).get_state()
d.update(
address=({"address": self.address(),
"use_ipv6": self.address.use_ipv6} if self.address else {}),
source_address=({"address": self.source_address(),
"use_ipv6": self.source_address.use_ipv6} if self.source_address else None),
cert=self.cert.to_pem() if self.cert else None
)
return d
def load_state(self, state):
super(ServerConnection, self).load_state(state)
self.address = tcp.Address(
**state["address"]) if state["address"] else None
self.source_address = tcp.Address(
**state["source_address"]) if state["source_address"] else None
self.cert = certutils.SSLCert.from_pem(
state["cert"]) if state["cert"] else None
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
f = cls(tuple()) f = cls(tuple())
f.load_state(state) f.set_state(state)
return f return f
def copy(self): def copy(self):

View File

@ -45,7 +45,7 @@ class Error(stateobject.StateObject):
# the default implementation assumes an empty constructor. Override # the default implementation assumes an empty constructor. Override
# accordingly. # accordingly.
f = cls(None) f = cls(None)
f.load_state(state) f.set_state(state)
return f return f
def copy(self): def copy(self):
@ -93,6 +93,12 @@ class Flow(stateobject.StateObject):
d.update(backup=self._backup) d.update(backup=self._backup)
return d return d
def set_state(self, state):
state.pop("version")
if "backup" in state:
self._backup = state.pop("backup")
super(Flow, self).set_state(state)
def __eq__(self, other): def __eq__(self, other):
return self is other return self is other
@ -130,7 +136,7 @@ class Flow(stateobject.StateObject):
Revert to the last backed up state. Revert to the last backed up state.
""" """
if self._backup: if self._backup:
self.load_state(self._backup) self.set_state(self._backup)
self._backup = None self._backup = None
def kill(self, master): def kill(self, master):

View File

@ -1,6 +1,7 @@
from __future__ import (absolute_import, print_function, division) from __future__ import (absolute_import, print_function, division)
import Cookie import Cookie
import copy import copy
import warnings
from email.utils import parsedate_tz, formatdate, mktime_tz from email.utils import parsedate_tz, formatdate, mktime_tz
import time import time
@ -8,28 +9,12 @@ from libmproxy import utils
from netlib import encoding from netlib import encoding
from netlib.http import status_codes, Headers, Request, Response, decoded from netlib.http import status_codes, Headers, Request, Response, decoded
from netlib.tcp import Address from netlib.tcp import Address
from .. import version, stateobject from .. import version
from .flow import Flow from .flow import Flow
class MessageMixin(stateobject.StateObject): class MessageMixin(object):
def get_state(self):
state = vars(self.data).copy()
state["headers"] = state["headers"].get_state()
return state
def load_state(self, state):
for k, v in state.items():
if k == "headers":
v = Headers.from_state(v)
setattr(self.data, k, v)
@classmethod
def from_state(cls, state):
state["headers"] = Headers.from_state(state["headers"])
return cls(**state)
def get_decoded_content(self): def get_decoded_content(self):
""" """
@ -136,6 +121,8 @@ class HTTPRequest(MessageMixin, Request):
timestamp_end=None, timestamp_end=None,
form_out=None, form_out=None,
is_replay=False, is_replay=False,
stickycookie=False,
stickyauth=False,
): ):
Request.__init__( Request.__init__(
self, self,
@ -154,21 +141,26 @@ class HTTPRequest(MessageMixin, Request):
self.form_out = form_out or first_line_format # FIXME remove self.form_out = form_out or first_line_format # FIXME remove
# Have this request's cookies been modified by sticky cookies or auth? # Have this request's cookies been modified by sticky cookies or auth?
self.stickycookie = False self.stickycookie = stickycookie
self.stickyauth = False self.stickyauth = stickyauth
# Is this request replayed? # Is this request replayed?
self.is_replay = is_replay self.is_replay = is_replay
@classmethod def get_state(self):
def from_protocol( state = super(HTTPRequest, self).get_state()
self, state.update(
protocol, stickycookie = self.stickycookie,
*args, stickyauth = self.stickyauth,
**kwargs is_replay = self.is_replay,
): )
req = protocol.read_request(*args, **kwargs) return state
return self.wrap(req)
def set_state(self, state):
self.stickycookie = state.pop("stickycookie")
self.stickyauth = state.pop("stickyauth")
self.is_replay = state.pop("is_replay")
super(HTTPRequest, self).set_state(state)
@classmethod @classmethod
def wrap(self, request): def wrap(self, request):
@ -188,6 +180,15 @@ class HTTPRequest(MessageMixin, Request):
) )
return req return req
@property
def form_out(self):
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
return self.first_line_format
@form_out.setter
def form_out(self, value):
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
def __hash__(self): def __hash__(self):
return id(self) return id(self)
@ -257,16 +258,6 @@ class HTTPResponse(MessageMixin, Response):
self.is_replay = is_replay self.is_replay = is_replay
self.stream = False self.stream = False
@classmethod
def from_protocol(
self,
protocol,
*args,
**kwargs
):
resp = protocol.read_response(*args, **kwargs)
return self.wrap(resp)
@classmethod @classmethod
def wrap(self, response): def wrap(self, response):
resp = HTTPResponse( resp = HTTPResponse(
@ -377,7 +368,7 @@ class HTTPFlow(Flow):
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
f = cls(None, None) f = cls(None, None)
f.load_state(state) f.set_state(state)
return f return f
def __repr__(self): def __repr__(self):

View File

@ -1,26 +1,25 @@
from __future__ import absolute_import from __future__ import absolute_import
from netlib.utils import Serializable
class StateObject(object): class StateObject(Serializable):
""" """
An object with serializable state. An object with serializable state.
State attributes can either be serializable types(str, tuple, bool, ...) State attributes can either be serializable types(str, tuple, bool, ...)
or StateObject instances themselves. or StateObject instances themselves.
""" """
# An attribute-name -> class-or-type dict containing all attributes that
# should be serialized. If the attribute is a class, it must implement the
# StateObject protocol.
_stateobject_attributes = None _stateobject_attributes = None
"""
def from_state(self, state): An attribute-name -> class-or-type dict containing all attributes that
raise NotImplementedError() should be serialized. If the attribute is a class, it must implement the
Serializable protocol.
"""
def get_state(self): def get_state(self):
""" """
Retrieve object state. If short is true, return an abbreviated Retrieve object state.
format with long data elided.
""" """
state = {} state = {}
for attr, cls in self._stateobject_attributes.iteritems(): for attr, cls in self._stateobject_attributes.iteritems():
@ -31,18 +30,22 @@ class StateObject(object):
state[attr] = val state[attr] = val
return state return state
def load_state(self, state): def set_state(self, state):
""" """
Load object state from data returned by a get_state call. Load object state from data returned by a get_state call.
""" """
state = state.copy()
for attr, cls in self._stateobject_attributes.iteritems(): for attr, cls in self._stateobject_attributes.iteritems():
if state.get(attr, None) is None: if state.get(attr) is None:
setattr(self, attr, None) setattr(self, attr, state.pop(attr))
else: else:
curr = getattr(self, attr) curr = getattr(self, attr)
if hasattr(curr, "load_state"): if hasattr(curr, "set_state"):
curr.load_state(state[attr]) curr.set_state(state.pop(attr))
elif hasattr(cls, "from_state"): elif hasattr(cls, "from_state"):
setattr(self, attr, cls.from_state(state[attr])) obj = cls.from_state(state.pop(attr))
else: setattr(self, attr, obj)
setattr(self, attr, cls(state[attr])) else: # primitive types such as int, str, ...
setattr(self, attr, cls(state.pop(attr)))
if state:
raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state))

View File

@ -170,7 +170,7 @@ class FlowHandler(RequestHandler):
elif k == "port": elif k == "port":
request.port = int(v) request.port = int(v)
elif k == "headers": elif k == "headers":
request.headers.load_state(v) request.headers.set_state(v)
else: else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v) print "Warning: Unknown update {}.{}: {}".format(a, k, v)
@ -184,7 +184,7 @@ class FlowHandler(RequestHandler):
elif k == "http_version": elif k == "http_version":
response.http_version = str(v) response.http_version = str(v)
elif k == "headers": elif k == "headers":
response.headers.load_state(v) response.headers.set_state(v)
else: else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v) print "Warning: Unknown update {}.{}: {}".format(a, k, v)
else: else:

View File

@ -422,7 +422,7 @@ class TestFlow(object):
assert not f == f2 assert not f == f2
f2.error = Error("e2") f2.error = Error("e2")
assert not f == f2 assert not f == f2
f.load_state(f2.get_state()) f.set_state(f2.get_state())
assert f.get_state() == f2.get_state() assert f.get_state() == f2.get_state()
def test_kill(self): def test_kill(self):
@ -1204,7 +1204,7 @@ class TestError:
e2 = Error("bar") e2 = Error("bar")
assert not e == e2 assert not e == e2
e.load_state(e2.get_state()) e.set_state(e2.get_state())
assert e.get_state() == e2.get_state() assert e.get_state() == e2.get_state()
e3 = e.copy() e3 = e.copy()
@ -1224,7 +1224,7 @@ class TestClientConnection:
assert not c == c2 assert not c == c2
c2.timestamp_start = 42 c2.timestamp_start = 42
c.load_state(c2.get_state()) c.set_state(c2.get_state())
assert c.timestamp_start == 42 assert c.timestamp_start == 42
c3 = c.copy() c3 = c.copy()

View File

@ -76,7 +76,11 @@ def tclient_conn():
""" """
c = ClientConnection.from_state(dict( c = ClientConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True), address=dict(address=("address", 22), use_ipv6=True),
clientcert=None clientcert=None,
ssl_established=False,
timestamp_start=1,
timestamp_ssl_setup=2,
timestamp_end=3,
)) ))
c.reply = controller.DummyReply() c.reply = controller.DummyReply()
return c return c
@ -88,9 +92,15 @@ def tserver_conn():
""" """
c = ServerConnection.from_state(dict( c = ServerConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True), address=dict(address=("address", 22), use_ipv6=True),
state=[],
source_address=dict(address=("address", 22), use_ipv6=True), source_address=dict(address=("address", 22), use_ipv6=True),
cert=None cert=None,
timestamp_start=1,
timestamp_tcp_setup=2,
timestamp_ssl_setup=3,
timestamp_end=4,
ssl_established=False,
sni="address",
via=None
)) ))
c.reply = controller.DummyReply() c.reply = controller.DummyReply()
return c return c