make stateobject simpler and stricter

This commit is contained in:
Maximilian Hils 2016-02-08 04:19:25 +01:00
parent e9934cc008
commit bdb763d9cf
8 changed files with 87 additions and 111 deletions

View File

@ -29,6 +29,7 @@ def convert_015_016(data):
data[m]["http_version"] = data[m].pop("httpversion")
if "msg" in data["response"]:
data["response"]["reason"] = data["response"].pop("msg")
data["request"].pop("form_out", None)
data["version"] = (0, 16)
return data

View File

@ -42,28 +42,14 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
return self.ssl_established
_stateobject_attributes = dict(
address=tcp.Address,
clientcert=certutils.SSLCert,
ssl_established=bool,
timestamp_start=float,
timestamp_end=float,
timestamp_ssl_setup=float
)
def get_state(self):
d = super(ClientConnection, self).get_state()
d.update(
address=({
"address": self.address(),
"use_ipv6": self.address.use_ipv6} if self.address else {}),
clientcert=self.cert.to_pem() if self.clientcert else None)
return d
def load_state(self, state):
super(ClientConnection, self).load_state(state)
self.address = tcp.Address(
**state["address"]) if state["address"] else None
self.clientcert = certutils.SSLCert.from_pem(
state["clientcert"]) if state["clientcert"] else None
def copy(self):
return copy.copy(self)
@ -76,7 +62,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod
def from_state(cls, state):
f = cls(None, tuple(), None)
f.load_state(state)
f.set_state(state)
return f
def convert_to_ssl(self, *args, **kwargs):
@ -131,31 +117,10 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
sni=str
)
def get_state(self):
d = super(ServerConnection, self).get_state()
d.update(
address=({"address": self.address(),
"use_ipv6": self.address.use_ipv6} if self.address else {}),
source_address=({"address": self.source_address(),
"use_ipv6": self.source_address.use_ipv6} if self.source_address else None),
cert=self.cert.to_pem() if self.cert else None
)
return d
def load_state(self, state):
super(ServerConnection, self).load_state(state)
self.address = tcp.Address(
**state["address"]) if state["address"] else None
self.source_address = tcp.Address(
**state["source_address"]) if state["source_address"] else None
self.cert = certutils.SSLCert.from_pem(
state["cert"]) if state["cert"] else None
@classmethod
def from_state(cls, state):
f = cls(tuple())
f.load_state(state)
f.set_state(state)
return f
def copy(self):

View File

@ -45,7 +45,7 @@ class Error(stateobject.StateObject):
# the default implementation assumes an empty constructor. Override
# accordingly.
f = cls(None)
f.load_state(state)
f.set_state(state)
return f
def copy(self):
@ -93,6 +93,12 @@ class Flow(stateobject.StateObject):
d.update(backup=self._backup)
return d
def set_state(self, state):
state.pop("version")
if "backup" in state:
self._backup = state.pop("backup")
super(Flow, self).set_state(state)
def __eq__(self, other):
return self is other
@ -130,7 +136,7 @@ class Flow(stateobject.StateObject):
Revert to the last backed up state.
"""
if self._backup:
self.load_state(self._backup)
self.set_state(self._backup)
self._backup = None
def kill(self, master):

View File

@ -1,6 +1,7 @@
from __future__ import (absolute_import, print_function, division)
import Cookie
import copy
import warnings
from email.utils import parsedate_tz, formatdate, mktime_tz
import time
@ -8,28 +9,12 @@ from libmproxy import utils
from netlib import encoding
from netlib.http import status_codes, Headers, Request, Response, decoded
from netlib.tcp import Address
from .. import version, stateobject
from .. import version
from .flow import Flow
class MessageMixin(stateobject.StateObject):
def get_state(self):
state = vars(self.data).copy()
state["headers"] = state["headers"].get_state()
return state
def load_state(self, state):
for k, v in state.items():
if k == "headers":
v = Headers.from_state(v)
setattr(self.data, k, v)
@classmethod
def from_state(cls, state):
state["headers"] = Headers.from_state(state["headers"])
return cls(**state)
class MessageMixin(object):
def get_decoded_content(self):
"""
@ -136,6 +121,8 @@ class HTTPRequest(MessageMixin, Request):
timestamp_end=None,
form_out=None,
is_replay=False,
stickycookie=False,
stickyauth=False,
):
Request.__init__(
self,
@ -154,21 +141,26 @@ class HTTPRequest(MessageMixin, Request):
self.form_out = form_out or first_line_format # FIXME remove
# Have this request's cookies been modified by sticky cookies or auth?
self.stickycookie = False
self.stickyauth = False
self.stickycookie = stickycookie
self.stickyauth = stickyauth
# Is this request replayed?
self.is_replay = is_replay
@classmethod
def from_protocol(
self,
protocol,
*args,
**kwargs
):
req = protocol.read_request(*args, **kwargs)
return self.wrap(req)
def get_state(self):
state = super(HTTPRequest, self).get_state()
state.update(
stickycookie = self.stickycookie,
stickyauth = self.stickyauth,
is_replay = self.is_replay,
)
return state
def set_state(self, state):
self.stickycookie = state.pop("stickycookie")
self.stickyauth = state.pop("stickyauth")
self.is_replay = state.pop("is_replay")
super(HTTPRequest, self).set_state(state)
@classmethod
def wrap(self, request):
@ -188,6 +180,15 @@ class HTTPRequest(MessageMixin, Request):
)
return req
@property
def form_out(self):
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
return self.first_line_format
@form_out.setter
def form_out(self, value):
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
def __hash__(self):
return id(self)
@ -257,16 +258,6 @@ class HTTPResponse(MessageMixin, Response):
self.is_replay = is_replay
self.stream = False
@classmethod
def from_protocol(
self,
protocol,
*args,
**kwargs
):
resp = protocol.read_response(*args, **kwargs)
return self.wrap(resp)
@classmethod
def wrap(self, response):
resp = HTTPResponse(
@ -377,7 +368,7 @@ class HTTPFlow(Flow):
@classmethod
def from_state(cls, state):
f = cls(None, None)
f.load_state(state)
f.set_state(state)
return f
def __repr__(self):

View File

@ -1,26 +1,25 @@
from __future__ import absolute_import
from netlib.utils import Serializable
class StateObject(object):
class StateObject(Serializable):
"""
An object with serializable state.
An object with serializable state.
State attributes can either be serializable types(str, tuple, bool, ...)
or StateObject instances themselves.
State attributes can either be serializable types(str, tuple, bool, ...)
or StateObject instances themselves.
"""
# An attribute-name -> class-or-type dict containing all attributes that
# should be serialized. If the attribute is a class, it must implement the
# StateObject protocol.
_stateobject_attributes = None
def from_state(self, state):
raise NotImplementedError()
"""
An attribute-name -> class-or-type dict containing all attributes that
should be serialized. If the attribute is a class, it must implement the
Serializable protocol.
"""
def get_state(self):
"""
Retrieve object state. If short is true, return an abbreviated
format with long data elided.
Retrieve object state.
"""
state = {}
for attr, cls in self._stateobject_attributes.iteritems():
@ -31,18 +30,22 @@ class StateObject(object):
state[attr] = val
return state
def load_state(self, state):
def set_state(self, state):
"""
Load object state from data returned by a get_state call.
Load object state from data returned by a get_state call.
"""
state = state.copy()
for attr, cls in self._stateobject_attributes.iteritems():
if state.get(attr, None) is None:
setattr(self, attr, None)
if state.get(attr) is None:
setattr(self, attr, state.pop(attr))
else:
curr = getattr(self, attr)
if hasattr(curr, "load_state"):
curr.load_state(state[attr])
if hasattr(curr, "set_state"):
curr.set_state(state.pop(attr))
elif hasattr(cls, "from_state"):
setattr(self, attr, cls.from_state(state[attr]))
else:
setattr(self, attr, cls(state[attr]))
obj = cls.from_state(state.pop(attr))
setattr(self, attr, obj)
else: # primitive types such as int, str, ...
setattr(self, attr, cls(state.pop(attr)))
if state:
raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state))

View File

@ -170,7 +170,7 @@ class FlowHandler(RequestHandler):
elif k == "port":
request.port = int(v)
elif k == "headers":
request.headers.load_state(v)
request.headers.set_state(v)
else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v)
@ -184,7 +184,7 @@ class FlowHandler(RequestHandler):
elif k == "http_version":
response.http_version = str(v)
elif k == "headers":
response.headers.load_state(v)
response.headers.set_state(v)
else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v)
else:

View File

@ -422,7 +422,7 @@ class TestFlow(object):
assert not f == f2
f2.error = Error("e2")
assert not f == f2
f.load_state(f2.get_state())
f.set_state(f2.get_state())
assert f.get_state() == f2.get_state()
def test_kill(self):
@ -1204,7 +1204,7 @@ class TestError:
e2 = Error("bar")
assert not e == e2
e.load_state(e2.get_state())
e.set_state(e2.get_state())
assert e.get_state() == e2.get_state()
e3 = e.copy()
@ -1224,7 +1224,7 @@ class TestClientConnection:
assert not c == c2
c2.timestamp_start = 42
c.load_state(c2.get_state())
c.set_state(c2.get_state())
assert c.timestamp_start == 42
c3 = c.copy()

View File

@ -76,7 +76,11 @@ def tclient_conn():
"""
c = ClientConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True),
clientcert=None
clientcert=None,
ssl_established=False,
timestamp_start=1,
timestamp_ssl_setup=2,
timestamp_end=3,
))
c.reply = controller.DummyReply()
return c
@ -88,9 +92,15 @@ def tserver_conn():
"""
c = ServerConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True),
state=[],
source_address=dict(address=("address", 22), use_ipv6=True),
cert=None
cert=None,
timestamp_start=1,
timestamp_tcp_setup=2,
timestamp_ssl_setup=3,
timestamp_end=4,
ssl_established=False,
sni="address",
via=None
))
c.reply = controller.DummyReply()
return c